Mirror of https://github.com/chidiwilliams/buzz.git, synced 2026-03-18 00:19:57 +01:00
Fix speaker identification chunk size error for long transcriptions (#1342)
Co-authored-by: Robrecht Siera <rob.developer.securemail@holoncom.eu>
parent 43214f5c3d
commit f1bc725e2b

2 changed files with 145 additions and 1 deletion
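For context, the failure mode this commit addresses: passing the full word list of a long transcription to PunctuationModel.predict in a single call can raise an AssertionError from the model's internal chunking (AssertionError is exactly the exception type the new code catches). Below is a minimal sketch of the pre-fix call shape; the 5000-word input is illustrative and not from the commit, while the import path and the predict call mirror the widget's own code.

from deepmultilingualpunctuation.deepmultilingualpunctuation import PunctuationModel

punct_model = PunctuationModel()
# A long transcription can yield thousands of words; the length here is made up.
words_list = [f"word{i}" for i in range(5000)]

# Pre-fix, the widget made one call over the entire list, which could trip
# an AssertionError inside the model's chunking logic on long inputs:
#   labeled_words = punct_model.predict(words_list, chunk_size=230)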
@@ -62,6 +62,63 @@ from whisper_diarization.helpers import (
 from deepmultilingualpunctuation.deepmultilingualpunctuation import PunctuationModel
 from whisper_diarization.diarization import MSDDDiarizer
 
+
+def process_in_batches(
+    items,
+    process_func,
+    batch_size=200,
+    chunk_size=230,
+    smaller_batch_size=100,
+    exception_types=(AssertionError,),
+    **process_func_kwargs
+):
+    """
+    Process items in batches with automatic fallback to smaller batches on errors.
+
+    This is a generic batch processing function that can be used with any processing
+    function that has chunk size limitations. It automatically retries with smaller
+    batches when specified exceptions occur.
+
+    Args:
+        items: List of items to process
+        process_func: Callable that processes a batch. Should accept (batch, chunk_size, **kwargs)
+            and return a list of results
+        batch_size: Initial batch size (default: 200)
+        chunk_size: Maximum chunk size for the processing function (default: 230)
+        smaller_batch_size: Fallback batch size when errors occur (default: 100)
+        exception_types: Tuple of exception types to catch and retry with smaller batches
+            (default: (AssertionError,))
+        **process_func_kwargs: Additional keyword arguments to pass to process_func
+
+    Returns:
+        List of processed results (concatenated from all batches)
+
+    Example:
+        >>> def my_predict(batch, chunk_size):
+        ...     return [f"processed_{item}" for item in batch]
+        >>> results = process_in_batches(
+        ...     items=["a", "b", "c"],
+        ...     process_func=my_predict,
+        ...     batch_size=2
+        ... )
+    """
+    all_results = []
+
+    for i in range(0, len(items), batch_size):
+        batch = items[i:i + batch_size]
+        try:
+            batch_results = process_func(batch, chunk_size=min(chunk_size, len(batch)), **process_func_kwargs)
+            all_results.extend(batch_results)
+        except exception_types as e:
+            # If batch still fails, try with even smaller chunks
+            logging.warning(f"Batch processing failed, trying smaller chunks: {e}")
+            for j in range(0, len(batch), smaller_batch_size):
+                smaller_batch = batch[j:j + smaller_batch_size]
+                smaller_results = process_func(smaller_batch, chunk_size=min(chunk_size, len(smaller_batch)), **process_func_kwargs)
+                all_results.extend(smaller_results)
+
+    return all_results
+
 SENTENCE_END = re.compile(r'.*[.!?。!?]')
 
 class IdentificationWorker(QObject):
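As a self-contained illustration of the fallback path (not part of the diff; flaky_process, the batch limit of 150, and the input size are made up, and the sketch assumes process_in_batches from the hunk above is in scope): a processing function that asserts on oversized batches fails at the initial batch_size and is transparently retried in smaller_batch_size pieces.

import logging

def flaky_process(batch, chunk_size, **kwargs):
    # Stand-in for a model call that rejects oversized inputs
    assert len(batch) <= 150, "batch too large"
    return [item.upper() for item in batch]

items = [f"w{i}" for i in range(400)]
# Each 200-item batch trips the assert, is caught, and is reprocessed
# as 100-item sub-batches, so every item still gets a result.
results = process_in_batches(items, flaky_process, batch_size=200, smaller_batch_size=100)
assert len(results) == 400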
@@ -267,7 +324,14 @@ class IdentificationWorker(QObject):
 
         words_list = list(map(lambda x: x["word"], wsm))
 
-        labled_words = punct_model.predict(words_list, chunk_size=230)
+        # Process in batches to avoid chunk size errors
+        def predict_wrapper(batch, chunk_size, **kwargs):
+            return punct_model.predict(batch, chunk_size=chunk_size)
+
+        labled_words = process_in_batches(
+            items=words_list,
+            process_func=predict_wrapper
+        )
 
         ending_puncts = ".?!。!?"
         model_puncts = ".,;:!?。!?"
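Note the design of this call site: predict_wrapper adapts punct_model.predict to the (batch, chunk_size, **kwargs) contract expected by process_in_batches and discards any extra keyword arguments, and since process_in_batches caps the chunk size at min(chunk_size, len(batch)), predict is never asked for a chunk larger than the batch it receives.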
@@ -14,6 +14,7 @@ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
     from buzz.widgets.transcription_viewer.speaker_identification_widget import (
         SpeakerIdentificationWidget,
         IdentificationWorker,
+        process_in_batches,
     )
     from tests.audio import test_audio_path
 
@@ -88,3 +89,82 @@ class TestSpeakerIdentificationWidget:
         assert isinstance(result[0], list)
         assert result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]]
+
+    def test_batch_processing_with_many_words(self):
+        """Test batch processing when there are more than 200 words."""
+        # Create mock punctuation model
+        mock_punct_model = MagicMock()
+        mock_punct_model.predict.side_effect = lambda batch, chunk_size: [
+            (word.strip(), ".") for word in batch
+        ]
+
+        # Create words list with 201 words (just enough to trigger batch processing)
+        words_list = [f"word{i}" for i in range(201)]
+
+        # Wrap predict method to match the expected signature
+        def predict_wrapper(batch, chunk_size, **kwargs):
+            return mock_punct_model.predict(batch, chunk_size=chunk_size)
+
+        # Call the generic batch processing function
+        result = process_in_batches(
+            items=words_list,
+            process_func=predict_wrapper
+        )
+
+        # Verify that predict was called multiple times (for batches)
+        assert mock_punct_model.predict.call_count >= 2, "Batch processing should split into multiple calls"
+
+        # Verify that each batch was processed with correct chunk_size
+        for call in mock_punct_model.predict.call_args_list:
+            args, kwargs = call
+            batch = args[0]
+            chunk_size = kwargs.get('chunk_size')
+            assert chunk_size <= 230, "Chunk size should not exceed 230"
+            assert len(batch) <= 200, "Batch size should not exceed 200"
+
+        # Verify result contains all words
+        assert len(result) == 201, "Result should contain all words"
+
+    def test_batch_processing_with_assertion_error_fallback(self):
+        """Test error handling when AssertionError occurs during batch processing."""
+        # Create mock punctuation model - raise AssertionError on first batch, then succeed
+        mock_punct_model = MagicMock()
+        call_count = [0]
+
+        def predict_side_effect(batch, chunk_size):
+            call_count[0] += 1
+            # Raise AssertionError on first call (first batch)
+            if call_count[0] == 1:
+                raise AssertionError("Chunk size too large")
+            # Succeed on subsequent calls (smaller batches)
+            return [(word.strip(), ".") for word in batch]
+
+        mock_punct_model.predict.side_effect = predict_side_effect
+
+        # Create words list with 201 words (enough to trigger batch processing)
+        words_list = [f"word{i}" for i in range(201)]
+
+        # Wrap predict method to match the expected signature
+        def predict_wrapper(batch, chunk_size, **kwargs):
+            return mock_punct_model.predict(batch, chunk_size=chunk_size)
+
+        # Call the generic batch processing function
+        result = process_in_batches(
+            items=words_list,
+            process_func=predict_wrapper
+        )
+
+        # Verify that predict was called multiple times
+        # First call fails, then smaller batches succeed
+        assert mock_punct_model.predict.call_count > 1, "Should retry with smaller batches after AssertionError"
+
+        # Verify that smaller batches were used after the error
+        call_args_list = mock_punct_model.predict.call_args_list
+        # After the first failed call, subsequent calls should have smaller batches
+        for i, call in enumerate(call_args_list[1:], start=1):  # Skip first failed call
+            args, kwargs = call
+            batch = args[0]
+            assert len(batch) <= 100, f"After AssertionError, batch size should be <= 100, got {len(batch)}"
+
+        # Verify result contains all words
+        assert len(result) == 201, "Result should contain all words"