mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 14:45:46 +01:00
Fix speaker identification chunk size error for long transcriptions (#1342)
Co-authored-by: Robrecht Siera <rob.developer.securemail@holoncom.eu>
This commit is contained in:
parent
43214f5c3d
commit
f1bc725e2b
2 changed files with 145 additions and 1 deletions
|
|
@ -14,6 +14,7 @@ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
|
|||
from buzz.widgets.transcription_viewer.speaker_identification_widget import (
|
||||
SpeakerIdentificationWidget,
|
||||
IdentificationWorker,
|
||||
process_in_batches,
|
||||
)
|
||||
from tests.audio import test_audio_path
|
||||
|
||||
|
|
@ -88,3 +89,82 @@ class TestSpeakerIdentificationWidget:
|
|||
assert isinstance(result[0], list)
|
||||
assert result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]]
|
||||
|
||||
def test_batch_processing_with_many_words(self):
|
||||
"""Test batch processing when there are more than 200 words."""
|
||||
# Create mock punctuation model
|
||||
mock_punct_model = MagicMock()
|
||||
mock_punct_model.predict.side_effect = lambda batch, chunk_size: [
|
||||
(word.strip(), ".") for word in batch
|
||||
]
|
||||
|
||||
# Create words list with 201 words (just enough to trigger batch processing)
|
||||
words_list = [f"word{i}" for i in range(201)]
|
||||
|
||||
# Wrap predict method to match the expected signature
|
||||
def predict_wrapper(batch, chunk_size, **kwargs):
|
||||
return mock_punct_model.predict(batch, chunk_size=chunk_size)
|
||||
|
||||
# Call the generic batch processing function
|
||||
result = process_in_batches(
|
||||
items=words_list,
|
||||
process_func=predict_wrapper
|
||||
)
|
||||
|
||||
# Verify that predict was called multiple times (for batches)
|
||||
assert mock_punct_model.predict.call_count >= 2, "Batch processing should split into multiple calls"
|
||||
|
||||
# Verify that each batch was processed with correct chunk_size
|
||||
for call in mock_punct_model.predict.call_args_list:
|
||||
args, kwargs = call
|
||||
batch = args[0]
|
||||
chunk_size = kwargs.get('chunk_size')
|
||||
assert chunk_size <= 230, "Chunk size should not exceed 230"
|
||||
assert len(batch) <= 200, "Batch size should not exceed 200"
|
||||
|
||||
# Verify result contains all words
|
||||
assert len(result) == 201, "Result should contain all words"
|
||||
|
||||
def test_batch_processing_with_assertion_error_fallback(self):
|
||||
"""Test error handling when AssertionError occurs during batch processing."""
|
||||
# Create mock punctuation model - raise AssertionError on first batch, then succeed
|
||||
mock_punct_model = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
def predict_side_effect(batch, chunk_size):
|
||||
call_count[0] += 1
|
||||
# Raise AssertionError on first call (first batch)
|
||||
if call_count[0] == 1:
|
||||
raise AssertionError("Chunk size too large")
|
||||
# Succeed on subsequent calls (smaller batches)
|
||||
return [(word.strip(), ".") for word in batch]
|
||||
|
||||
mock_punct_model.predict.side_effect = predict_side_effect
|
||||
|
||||
# Create words list with 201 words (enough to trigger batch processing)
|
||||
words_list = [f"word{i}" for i in range(201)]
|
||||
|
||||
# Wrap predict method to match the expected signature
|
||||
def predict_wrapper(batch, chunk_size, **kwargs):
|
||||
return mock_punct_model.predict(batch, chunk_size=chunk_size)
|
||||
|
||||
# Call the generic batch processing function
|
||||
result = process_in_batches(
|
||||
items=words_list,
|
||||
process_func=predict_wrapper
|
||||
)
|
||||
|
||||
# Verify that predict was called multiple times
|
||||
# First call fails, then smaller batches succeed
|
||||
assert mock_punct_model.predict.call_count > 1, "Should retry with smaller batches after AssertionError"
|
||||
|
||||
# Verify that smaller batches were used after the error
|
||||
call_args_list = mock_punct_model.predict.call_args_list
|
||||
# After the first failed call, subsequent calls should have smaller batches
|
||||
for i, call in enumerate(call_args_list[1:], start=1): # Skip first failed call
|
||||
args, kwargs = call
|
||||
batch = args[0]
|
||||
assert len(batch) <= 100, f"After AssertionError, batch size should be <= 100, got {len(batch)}"
|
||||
|
||||
# Verify result contains all words
|
||||
assert len(result) == 201, "Result should contain all words"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue