From f1bc725e2b26e41ddedbb440b35cb27b3351fb5b Mon Sep 17 00:00:00 2001
From: Rob Siera
Date: Sat, 10 Jan 2026 10:38:55 +0100
Subject: [PATCH] Fix speaker identification chunk size error for long transcriptions (#1342)

Co-authored-by: Robrecht Siera
---
 .../speaker_identification_widget.py          | 66 ++++++++++++++-
 .../speaker_identification_widget_test.py     | 80 +++++++++++++++++++
 2 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/buzz/widgets/transcription_viewer/speaker_identification_widget.py b/buzz/widgets/transcription_viewer/speaker_identification_widget.py
index 6ea6eec1..cc794419 100644
--- a/buzz/widgets/transcription_viewer/speaker_identification_widget.py
+++ b/buzz/widgets/transcription_viewer/speaker_identification_widget.py
@@ -62,6 +62,63 @@ from whisper_diarization.helpers import (
 from deepmultilingualpunctuation.deepmultilingualpunctuation import PunctuationModel
 from whisper_diarization.diarization import MSDDDiarizer
 
+
+def process_in_batches(
+    items,
+    process_func,
+    batch_size=200,
+    chunk_size=230,
+    smaller_batch_size=100,
+    exception_types=(AssertionError,),
+    **process_func_kwargs
+):
+    """
+    Process items in batches with automatic fallback to smaller batches on errors.
+
+    This is a generic batch processing function that can be used with any processing
+    function that has chunk size limitations. It automatically retries with smaller
+    batches when specified exceptions occur.
+
+    Args:
+        items: List of items to process
+        process_func: Callable that processes a batch. Should accept (batch, chunk_size, **kwargs)
+            and return a list of results
+        batch_size: Initial batch size (default: 200)
+        chunk_size: Maximum chunk size for the processing function (default: 230)
+        smaller_batch_size: Fallback batch size when errors occur (default: 100)
+        exception_types: Tuple of exception types to catch and retry with smaller batches
+            (default: (AssertionError,))
+        **process_func_kwargs: Additional keyword arguments to pass to process_func
+
+    Returns:
+        List of processed results (concatenated from all batches)
+
+    Example:
+        >>> def my_predict(batch, chunk_size):
+        ...     return [f"processed_{item}" for item in batch]
+        >>> results = process_in_batches(
+        ...     items=["a", "b", "c"],
+        ...     process_func=my_predict,
+        ...     batch_size=2
+        ... )
+    """
+    all_results = []
+
+    for i in range(0, len(items), batch_size):
+        batch = items[i:i + batch_size]
+        try:
+            batch_results = process_func(batch, chunk_size=min(chunk_size, len(batch)), **process_func_kwargs)
+            all_results.extend(batch_results)
+        except exception_types as e:
+            # First attempt failed: retry this batch in smaller sub-batches (NOTE: a failure there still propagates)
+            logging.warning("Batch processing failed, trying smaller chunks: %s", e)
+            for j in range(0, len(batch), smaller_batch_size):
+                smaller_batch = batch[j:j + smaller_batch_size]
+                smaller_results = process_func(smaller_batch, chunk_size=min(chunk_size, len(smaller_batch)), **process_func_kwargs)
+                all_results.extend(smaller_results)
+
+    return all_results
+
 SENTENCE_END = re.compile(r'.*[.!?。!?]')
 
 class IdentificationWorker(QObject):
@@ -267,7 +324,14 @@ class IdentificationWorker(QObject):
 
         words_list = list(map(lambda x: x["word"], wsm))
 
-        labled_words = punct_model.predict(words_list, chunk_size=230)
+        # Process in batches to avoid chunk size errors
+        def predict_wrapper(batch, chunk_size, **kwargs):
+            return punct_model.predict(batch, chunk_size=chunk_size)
+
+        labled_words = process_in_batches(
+            items=words_list,
+            process_func=predict_wrapper
+        )
 
         ending_puncts = ".?!。!?"
         model_puncts = ".,;:!?。!?"
diff --git a/tests/widgets/speaker_identification_widget_test.py b/tests/widgets/speaker_identification_widget_test.py
index 946948dc..5f10e6ce 100644
--- a/tests/widgets/speaker_identification_widget_test.py
+++ b/tests/widgets/speaker_identification_widget_test.py
@@ -14,6 +14,7 @@ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
     from buzz.widgets.transcription_viewer.speaker_identification_widget import (
         SpeakerIdentificationWidget,
         IdentificationWorker,
+        process_in_batches,
     )
 
 from tests.audio import test_audio_path
@@ -88,3 +89,82 @@ class TestSpeakerIdentificationWidget:
 
         assert isinstance(result[0], list)
         assert result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]]
+
+    def test_batch_processing_with_many_words(self):
+        """Test batch processing when there are more than 200 words."""
+        # Create mock punctuation model
+        mock_punct_model = MagicMock()
+        mock_punct_model.predict.side_effect = lambda batch, chunk_size: [
+            (word.strip(), ".") for word in batch
+        ]
+
+        # Create words list with 201 words (just enough to trigger batch processing)
+        words_list = [f"word{i}" for i in range(201)]
+
+        # Wrap predict method to match the expected signature
+        def predict_wrapper(batch, chunk_size, **kwargs):
+            return mock_punct_model.predict(batch, chunk_size=chunk_size)
+
+        # Call the generic batch processing function
+        result = process_in_batches(
+            items=words_list,
+            process_func=predict_wrapper
+        )
+
+        # Verify that predict was called multiple times (for batches)
+        assert mock_punct_model.predict.call_count >= 2, "Batch processing should split into multiple calls"
+
+        # Verify that each batch was processed with correct chunk_size
+        for call in mock_punct_model.predict.call_args_list:
+            args, kwargs = call
+            batch = args[0]
+            chunk_size = kwargs.get('chunk_size')
+            assert chunk_size <= 230, "Chunk size should not exceed 230"
+            assert len(batch) <= 200, "Batch size should not exceed 200"
+
+        # Verify result contains all words
+        assert len(result) == 201, "Result should contain all words"
+
+    def test_batch_processing_with_assertion_error_fallback(self):
+        """Test error handling when AssertionError occurs during batch processing."""
+        # Create mock punctuation model - raise AssertionError on first batch, then succeed
+        mock_punct_model = MagicMock()
+        call_count = [0]
+
+        def predict_side_effect(batch, chunk_size):
+            call_count[0] += 1
+            # Raise AssertionError on first call (first batch)
+            if call_count[0] == 1:
+                raise AssertionError("Chunk size too large")
+            # Succeed on subsequent calls (smaller batches)
+            return [(word.strip(), ".") for word in batch]
+
+        mock_punct_model.predict.side_effect = predict_side_effect
+
+        # Create words list with 201 words (enough to trigger batch processing)
+        words_list = [f"word{i}" for i in range(201)]
+
+        # Wrap predict method to match the expected signature
+        def predict_wrapper(batch, chunk_size, **kwargs):
+            return mock_punct_model.predict(batch, chunk_size=chunk_size)
+
+        # Call the generic batch processing function
+        result = process_in_batches(
+            items=words_list,
+            process_func=predict_wrapper
+        )
+
+        # Verify that predict was called multiple times
+        # First call fails, then smaller batches succeed
+        assert mock_punct_model.predict.call_count > 1, "Should retry with smaller batches after AssertionError"
+
+        # Verify that smaller batches were used after the error
+        call_args_list = mock_punct_model.predict.call_args_list
+        # After the first failed call, subsequent calls should have smaller batches
+        for i, call in enumerate(call_args_list[1:], start=1):  # Skip first failed call
+            args, kwargs = call
+            batch = args[0]
+            assert len(batch) <= 100, f"After AssertionError, batch size should be <= 100, got {len(batch)}"
+
+        # Verify result contains all words
+        assert len(result) == 201, "Result should contain all words"