Fix speaker identification chunk size error for long transcriptions (#1342)

Co-authored-by: Robrecht Siera <rob.developer.securemail@holoncom.eu>
This commit is contained in:
Rob Siera 2026-01-10 10:38:55 +01:00 committed by GitHub
commit f1bc725e2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 145 additions and 1 deletions

View file

@ -14,6 +14,7 @@ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
from buzz.widgets.transcription_viewer.speaker_identification_widget import (
SpeakerIdentificationWidget,
IdentificationWorker,
process_in_batches,
)
from tests.audio import test_audio_path
@ -88,3 +89,82 @@ class TestSpeakerIdentificationWidget:
assert isinstance(result[0], list)
assert result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]]
def test_batch_processing_with_many_words(self):
"""Test batch processing when there are more than 200 words."""
# Create mock punctuation model
mock_punct_model = MagicMock()
mock_punct_model.predict.side_effect = lambda batch, chunk_size: [
(word.strip(), ".") for word in batch
]
# Create words list with 201 words (just enough to trigger batch processing)
words_list = [f"word{i}" for i in range(201)]
# Wrap predict method to match the expected signature
def predict_wrapper(batch, chunk_size, **kwargs):
return mock_punct_model.predict(batch, chunk_size=chunk_size)
# Call the generic batch processing function
result = process_in_batches(
items=words_list,
process_func=predict_wrapper
)
# Verify that predict was called multiple times (for batches)
assert mock_punct_model.predict.call_count >= 2, "Batch processing should split into multiple calls"
# Verify that each batch was processed with correct chunk_size
for call in mock_punct_model.predict.call_args_list:
args, kwargs = call
batch = args[0]
chunk_size = kwargs.get('chunk_size')
assert chunk_size <= 230, "Chunk size should not exceed 230"
assert len(batch) <= 200, "Batch size should not exceed 200"
# Verify result contains all words
assert len(result) == 201, "Result should contain all words"
def test_batch_processing_with_assertion_error_fallback(self):
"""Test error handling when AssertionError occurs during batch processing."""
# Create mock punctuation model - raise AssertionError on first batch, then succeed
mock_punct_model = MagicMock()
call_count = [0]
def predict_side_effect(batch, chunk_size):
call_count[0] += 1
# Raise AssertionError on first call (first batch)
if call_count[0] == 1:
raise AssertionError("Chunk size too large")
# Succeed on subsequent calls (smaller batches)
return [(word.strip(), ".") for word in batch]
mock_punct_model.predict.side_effect = predict_side_effect
# Create words list with 201 words (enough to trigger batch processing)
words_list = [f"word{i}" for i in range(201)]
# Wrap predict method to match the expected signature
def predict_wrapper(batch, chunk_size, **kwargs):
return mock_punct_model.predict(batch, chunk_size=chunk_size)
# Call the generic batch processing function
result = process_in_batches(
items=words_list,
process_func=predict_wrapper
)
# Verify that predict was called multiple times
# First call fails, then smaller batches succeed
assert mock_punct_model.predict.call_count > 1, "Should retry with smaller batches after AssertionError"
# Verify that smaller batches were used after the error
call_args_list = mock_punct_model.predict.call_args_list
# After the first failed call, subsequent calls should have smaller batches
for i, call in enumerate(call_args_list[1:], start=1): # Skip first failed call
args, kwargs = call
batch = args[0]
assert len(batch) <= 100, f"After AssertionError, batch size should be <= 100, got {len(batch)}"
# Verify result contains all words
assert len(result) == 201, "Result should contain all words"