Fix speaker identification chunk size error for long transcriptions (#1342)

Co-authored-by: Robrecht Siera <rob.developer.securemail@holoncom.eu>
This commit is contained in:
Rob Siera 2026-01-10 10:38:55 +01:00 committed by GitHub
commit f1bc725e2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 145 additions and 1 deletions

View file

@ -62,6 +62,63 @@ from whisper_diarization.helpers import (
from deepmultilingualpunctuation.deepmultilingualpunctuation import PunctuationModel
from whisper_diarization.diarization import MSDDDiarizer
def process_in_batches(
items,
process_func,
batch_size=200,
chunk_size=230,
smaller_batch_size=100,
exception_types=(AssertionError,),
**process_func_kwargs
):
"""
Process items in batches with automatic fallback to smaller batches on errors.
This is a generic batch processing function that can be used with any processing
function that has chunk size limitations. It automatically retries with smaller
batches when specified exceptions occur.
Args:
items: List of items to process
process_func: Callable that processes a batch. Should accept (batch, chunk_size, **kwargs)
and return a list of results
batch_size: Initial batch size (default: 200)
chunk_size: Maximum chunk size for the processing function (default: 230)
smaller_batch_size: Fallback batch size when errors occur (default: 100)
exception_types: Tuple of exception types to catch and retry with smaller batches
(default: (AssertionError,))
**process_func_kwargs: Additional keyword arguments to pass to process_func
Returns:
List of processed results (concatenated from all batches)
Example:
>>> def my_predict(batch, chunk_size):
... return [f"processed_{item}" for item in batch]
>>> results = process_in_batches(
... items=["a", "b", "c"],
... process_func=my_predict,
... batch_size=2
... )
"""
all_results = []
for i in range(0, len(items), batch_size):
batch = items[i:i + batch_size]
try:
batch_results = process_func(batch, chunk_size=min(chunk_size, len(batch)), **process_func_kwargs)
all_results.extend(batch_results)
except exception_types as e:
# If batch still fails, try with even smaller chunks
logging.warning(f"Batch processing failed, trying smaller chunks: {e}")
for j in range(0, len(batch), smaller_batch_size):
smaller_batch = batch[j:j + smaller_batch_size]
smaller_results = process_func(smaller_batch, chunk_size=min(chunk_size, len(smaller_batch)), **process_func_kwargs)
all_results.extend(smaller_results)
return all_results
SENTENCE_END = re.compile(r'.*[.!?。!?]')
class IdentificationWorker(QObject):
@ -267,7 +324,14 @@ class IdentificationWorker(QObject):
words_list = list(map(lambda x: x["word"], wsm))
labled_words = punct_model.predict(words_list, chunk_size=230)
# Process in batches to avoid chunk size errors
def predict_wrapper(batch, chunk_size, **kwargs):
return punct_model.predict(batch, chunk_size=chunk_size)
labled_words = process_in_batches(
items=words_list,
process_func=predict_wrapper
)
ending_puncts = ".?!。!?"
model_puncts = ".,;:!?。!?"

View file

@ -14,6 +14,7 @@ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
from buzz.widgets.transcription_viewer.speaker_identification_widget import (
SpeakerIdentificationWidget,
IdentificationWorker,
process_in_batches,
)
from tests.audio import test_audio_path
@ -88,3 +89,82 @@ class TestSpeakerIdentificationWidget:
assert isinstance(result[0], list)
assert result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]]
def test_batch_processing_with_many_words(self):
"""Test batch processing when there are more than 200 words."""
# Create mock punctuation model
mock_punct_model = MagicMock()
mock_punct_model.predict.side_effect = lambda batch, chunk_size: [
(word.strip(), ".") for word in batch
]
# Create words list with 201 words (just enough to trigger batch processing)
words_list = [f"word{i}" for i in range(201)]
# Wrap predict method to match the expected signature
def predict_wrapper(batch, chunk_size, **kwargs):
return mock_punct_model.predict(batch, chunk_size=chunk_size)
# Call the generic batch processing function
result = process_in_batches(
items=words_list,
process_func=predict_wrapper
)
# Verify that predict was called multiple times (for batches)
assert mock_punct_model.predict.call_count >= 2, "Batch processing should split into multiple calls"
# Verify that each batch was processed with correct chunk_size
for call in mock_punct_model.predict.call_args_list:
args, kwargs = call
batch = args[0]
chunk_size = kwargs.get('chunk_size')
assert chunk_size <= 230, "Chunk size should not exceed 230"
assert len(batch) <= 200, "Batch size should not exceed 200"
# Verify result contains all words
assert len(result) == 201, "Result should contain all words"
def test_batch_processing_with_assertion_error_fallback(self):
"""Test error handling when AssertionError occurs during batch processing."""
# Create mock punctuation model - raise AssertionError on first batch, then succeed
mock_punct_model = MagicMock()
call_count = [0]
def predict_side_effect(batch, chunk_size):
call_count[0] += 1
# Raise AssertionError on first call (first batch)
if call_count[0] == 1:
raise AssertionError("Chunk size too large")
# Succeed on subsequent calls (smaller batches)
return [(word.strip(), ".") for word in batch]
mock_punct_model.predict.side_effect = predict_side_effect
# Create words list with 201 words (enough to trigger batch processing)
words_list = [f"word{i}" for i in range(201)]
# Wrap predict method to match the expected signature
def predict_wrapper(batch, chunk_size, **kwargs):
return mock_punct_model.predict(batch, chunk_size=chunk_size)
# Call the generic batch processing function
result = process_in_batches(
items=words_list,
process_func=predict_wrapper
)
# Verify that predict was called multiple times
# First call fails, then smaller batches succeed
assert mock_punct_model.predict.call_count > 1, "Should retry with smaller batches after AssertionError"
# Verify that smaller batches were used after the error
call_args_list = mock_punct_model.predict.call_args_list
# After the first failed call, subsequent calls should have smaller batches
for i, call in enumerate(call_args_list[1:], start=1): # Skip first failed call
args, kwargs = call
batch = args[0]
assert len(batch) <= 100, f"After AssertionError, batch size should be <= 100, got {len(batch)}"
# Verify result contains all words
assert len(result) == 201, "Result should contain all words"