mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 14:45:46 +01:00
382 lines
16 KiB
Python
382 lines
16 KiB
Python
import logging
|
|
import platform
|
|
import time
|
|
import uuid
|
|
import pytest
|
|
from pytestqt.qtbot import QtBot
|
|
from unittest.mock import MagicMock, patch
|
|
from buzz.db.entity.transcription import Transcription
|
|
from buzz.db.entity.transcription_segment import TranscriptionSegment
|
|
from buzz.model_loader import ModelType, WhisperModelSize
|
|
from buzz.transcriber.transcriber import Task
|
|
# Underlying libs do not support intel Macs or Windows (nemo C extensions crash on Windows CI)
|
|
if not (platform.system() == "Darwin" and platform.machine() == "x86_64") and platform.system() != "Windows":
|
|
from buzz.widgets.transcription_viewer.speaker_identification_widget import (
|
|
SpeakerIdentificationWidget,
|
|
IdentificationWorker,
|
|
process_in_batches,
|
|
)
|
|
from tests.audio import test_audio_path
|
|
|
|
@pytest.mark.skipif(
|
|
(platform.system() == "Darwin" and platform.machine() == "x86_64") or platform.system() == "Windows",
|
|
reason="Speaker identification dependencies (nemo/texterrors C extensions) crash on Windows and are unsupported on Intel Mac"
|
|
)
|
|
class TestSpeakerIdentificationWidget:
|
|
@pytest.fixture()
|
|
def transcription(
|
|
self, transcription_dao, transcription_segment_dao
|
|
) -> Transcription:
|
|
id = uuid.uuid4()
|
|
transcription_dao.insert(
|
|
Transcription(
|
|
id=str(id),
|
|
status="completed",
|
|
file=test_audio_path,
|
|
task=Task.TRANSCRIBE.value,
|
|
model_type=ModelType.WHISPER.value,
|
|
whisper_model_size=WhisperModelSize.SMALL.value,
|
|
)
|
|
)
|
|
transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", "", str(id)))
|
|
transcription_segment_dao.insert(
|
|
TranscriptionSegment(299, 329, "venue dans", "", str(id))
|
|
)
|
|
|
|
return transcription_dao.find_by_id(str(id))
|
|
|
|
def test_widget_initialization(self, qtbot: QtBot, transcription, transcription_service):
|
|
"""Test the initialization of SpeakerIdentificationWidget."""
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
assert widget.transcription == transcription
|
|
assert widget.transcription_service == transcription_service
|
|
assert widget.progress_bar.value() == 0
|
|
|
|
widget.close()
|
|
|
|
# Wait to clean-up threads
|
|
time.sleep(3)
|
|
|
|
@pytest.mark.skipif(
|
|
platform.system() == "Linux",
|
|
reason="Skip speaker identification worker test on Linux, CI freezes"
|
|
)
|
|
@patch("buzz.widgets.transcription_viewer.speaker_identification_widget.IdentificationWorker")
|
|
def test_identification_worker_run(self, qtbot: QtBot, transcription, transcription_service):
|
|
"""Test the IdentificationWorker's run method and capture the finished signal result."""
|
|
worker = IdentificationWorker(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
|
|
result = []
|
|
|
|
def capture_result(data):
|
|
result.append(data)
|
|
|
|
worker.finished.connect(capture_result)
|
|
|
|
with qtbot.waitSignal(worker.finished, timeout= 300000): #5 min timeout
|
|
worker.run()
|
|
|
|
assert worker.transcription == transcription
|
|
assert len(result) == 1
|
|
assert isinstance(result[0], list)
|
|
assert (result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bien venue dans. '}]]
|
|
or result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]])
|
|
|
|
def test_identify_button_toggles_visibility(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
# Before: identify visible, cancel hidden
|
|
assert not widget.step_1_button.isHidden()
|
|
assert widget.cancel_button.isHidden()
|
|
|
|
from PyQt6.QtCore import QThread as RealQThread
|
|
mock_thread = MagicMock(spec=RealQThread)
|
|
mock_thread.started = MagicMock()
|
|
mock_thread.started.connect = MagicMock()
|
|
|
|
with patch.object(widget, '_cleanup_thread'), \
|
|
patch('buzz.widgets.transcription_viewer.speaker_identification_widget.QThread', return_value=mock_thread), \
|
|
patch.object(widget, 'worker', create=True):
|
|
# patch moveToThread on IdentificationWorker to avoid type error
|
|
with patch.object(IdentificationWorker, 'moveToThread'):
|
|
widget.on_identify_button_clicked()
|
|
|
|
# After: identify hidden, cancel visible
|
|
assert widget.step_1_button.isHidden()
|
|
assert not widget.cancel_button.isHidden()
|
|
|
|
widget.close()
|
|
|
|
def test_cancel_button_resets_ui(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
# Simulate identification started
|
|
widget.step_1_button.setVisible(False)
|
|
widget.cancel_button.setVisible(True)
|
|
|
|
with patch.object(widget, '_cleanup_thread'):
|
|
widget.on_cancel_button_clicked()
|
|
|
|
assert not widget.step_1_button.isHidden()
|
|
assert widget.cancel_button.isHidden()
|
|
assert widget.progress_bar.value() == 0
|
|
assert len(widget.progress_label.text()) > 0
|
|
|
|
widget.close()
|
|
|
|
def test_on_progress_update_sets_label_and_bar(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
widget.on_progress_update("3/8 Loading alignment model")
|
|
|
|
assert widget.progress_label.text() == "3/8 Loading alignment model"
|
|
assert widget.progress_bar.value() == 3
|
|
|
|
widget.close()
|
|
|
|
def test_on_progress_update_step_8_enables_save(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
assert not widget.save_button.isEnabled()
|
|
|
|
widget.on_progress_update("8/8 Identification done")
|
|
|
|
assert widget.save_button.isEnabled()
|
|
assert widget.step_2_group_box.isEnabled()
|
|
assert widget.merge_speaker_sentences.isEnabled()
|
|
|
|
widget.close()
|
|
|
|
def test_on_identification_finished_empty_result(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
initial_row_count = widget.speaker_preview_row.count()
|
|
|
|
widget.on_identification_finished([])
|
|
|
|
assert widget.identification_result == []
|
|
# Empty result returns early — speaker preview row unchanged
|
|
assert widget.speaker_preview_row.count() == initial_row_count
|
|
|
|
widget.close()
|
|
|
|
def test_on_identification_finished_populates_speakers(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
result = [
|
|
{'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 3000, 'text': 'Hello world.'},
|
|
{'speaker': 'Speaker 1', 'start_time': 3000, 'end_time': 6000, 'text': 'Hi there.'},
|
|
]
|
|
widget.on_identification_finished(result)
|
|
|
|
assert widget.identification_result == result
|
|
# Two speaker rows should have been created
|
|
assert widget.speaker_preview_row.count() == 2
|
|
|
|
widget.close()
|
|
|
|
def test_on_identification_error_resets_buttons(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
widget.step_1_button.setVisible(False)
|
|
widget.cancel_button.setVisible(True)
|
|
|
|
widget.on_identification_error("Some error")
|
|
|
|
assert not widget.step_1_button.isHidden()
|
|
assert widget.cancel_button.isHidden()
|
|
assert widget.progress_bar.value() == 0
|
|
|
|
widget.close()
|
|
|
|
def test_on_save_no_merge(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
result = [
|
|
{'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 2000, 'text': 'Hello.'},
|
|
{'speaker': 'Speaker 0', 'start_time': 2000, 'end_time': 4000, 'text': 'World.'},
|
|
{'speaker': 'Speaker 1', 'start_time': 4000, 'end_time': 6000, 'text': 'Hi.'},
|
|
]
|
|
widget.on_identification_finished(result)
|
|
widget.merge_speaker_sentences.setChecked(False)
|
|
|
|
with patch.object(widget.transcription_service, 'copy_transcription', return_value=uuid.uuid4()) as mock_copy, \
|
|
patch.object(widget.transcription_service, 'update_transcription_as_completed') as mock_update:
|
|
widget.on_save_button_clicked()
|
|
|
|
mock_copy.assert_called_once()
|
|
mock_update.assert_called_once()
|
|
segments = mock_update.call_args[0][1]
|
|
# No merge: 3 entries → 3 segments
|
|
assert len(segments) == 3
|
|
|
|
widget.close()
|
|
|
|
def test_on_save_with_merge(self, qtbot: QtBot, transcription, transcription_service):
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
result = [
|
|
{'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 2000, 'text': 'Hello.'},
|
|
{'speaker': 'Speaker 0', 'start_time': 2000, 'end_time': 4000, 'text': 'World.'},
|
|
{'speaker': 'Speaker 1', 'start_time': 4000, 'end_time': 6000, 'text': 'Hi.'},
|
|
]
|
|
widget.on_identification_finished(result)
|
|
widget.merge_speaker_sentences.setChecked(True)
|
|
|
|
with patch.object(widget.transcription_service, 'copy_transcription', return_value=uuid.uuid4()), \
|
|
patch.object(widget.transcription_service, 'update_transcription_as_completed') as mock_update:
|
|
widget.on_save_button_clicked()
|
|
|
|
segments = mock_update.call_args[0][1]
|
|
# Merge: two consecutive Speaker 0 entries → merged into 1; Speaker 1 → 1 = 2 total
|
|
assert len(segments) == 2
|
|
assert "Speaker 0" in segments[0].text
|
|
assert "Hello." in segments[0].text
|
|
assert "World." in segments[0].text
|
|
|
|
widget.close()
|
|
|
|
def test_on_save_emits_transcriptions_updated(self, qtbot: QtBot, transcription, transcription_service):
|
|
updated_signal = MagicMock()
|
|
widget = SpeakerIdentificationWidget(
|
|
transcription=transcription,
|
|
transcription_service=transcription_service,
|
|
transcriptions_updated_signal=updated_signal,
|
|
)
|
|
qtbot.addWidget(widget)
|
|
|
|
result = [{'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 1000, 'text': 'Hi.'}]
|
|
widget.on_identification_finished(result)
|
|
|
|
new_id = uuid.uuid4()
|
|
with patch.object(widget.transcription_service, 'copy_transcription', return_value=new_id), \
|
|
patch.object(widget.transcription_service, 'update_transcription_as_completed'):
|
|
widget.on_save_button_clicked()
|
|
|
|
updated_signal.emit.assert_called_once_with(new_id)
|
|
|
|
widget.close()
|
|
|
|
def test_batch_processing_with_many_words(self):
|
|
"""Test batch processing when there are more than 200 words."""
|
|
# Create mock punctuation model
|
|
mock_punct_model = MagicMock()
|
|
mock_punct_model.predict.side_effect = lambda batch, chunk_size: [
|
|
(word.strip(), ".") for word in batch
|
|
]
|
|
|
|
# Create words list with 201 words (just enough to trigger batch processing)
|
|
words_list = [f"word{i}" for i in range(201)]
|
|
|
|
# Wrap predict method to match the expected signature
|
|
def predict_wrapper(batch, chunk_size, **kwargs):
|
|
return mock_punct_model.predict(batch, chunk_size=chunk_size)
|
|
|
|
# Call the generic batch processing function
|
|
result = process_in_batches(
|
|
items=words_list,
|
|
process_func=predict_wrapper
|
|
)
|
|
|
|
# Verify that predict was called multiple times (for batches)
|
|
assert mock_punct_model.predict.call_count >= 2, "Batch processing should split into multiple calls"
|
|
|
|
# Verify that each batch was processed with correct chunk_size
|
|
for call in mock_punct_model.predict.call_args_list:
|
|
args, kwargs = call
|
|
batch = args[0]
|
|
chunk_size = kwargs.get('chunk_size')
|
|
assert chunk_size <= 230, "Chunk size should not exceed 230"
|
|
assert len(batch) <= 200, "Batch size should not exceed 200"
|
|
|
|
# Verify result contains all words
|
|
assert len(result) == 201, "Result should contain all words"
|
|
|
|
def test_batch_processing_with_assertion_error_fallback(self):
|
|
"""Test error handling when AssertionError occurs during batch processing."""
|
|
# Create mock punctuation model - raise AssertionError on first batch, then succeed
|
|
mock_punct_model = MagicMock()
|
|
call_count = [0]
|
|
|
|
def predict_side_effect(batch, chunk_size):
|
|
call_count[0] += 1
|
|
# Raise AssertionError on first call (first batch)
|
|
if call_count[0] == 1:
|
|
raise AssertionError("Chunk size too large")
|
|
# Succeed on subsequent calls (smaller batches)
|
|
return [(word.strip(), ".") for word in batch]
|
|
|
|
mock_punct_model.predict.side_effect = predict_side_effect
|
|
|
|
# Create words list with 201 words (enough to trigger batch processing)
|
|
words_list = [f"word{i}" for i in range(201)]
|
|
|
|
# Wrap predict method to match the expected signature
|
|
def predict_wrapper(batch, chunk_size, **kwargs):
|
|
return mock_punct_model.predict(batch, chunk_size=chunk_size)
|
|
|
|
# Call the generic batch processing function
|
|
result = process_in_batches(
|
|
items=words_list,
|
|
process_func=predict_wrapper
|
|
)
|
|
|
|
# Verify that predict was called multiple times
|
|
# First call fails, then smaller batches succeed
|
|
assert mock_punct_model.predict.call_count > 1, "Should retry with smaller batches after AssertionError"
|
|
|
|
# Verify that smaller batches were used after the error
|
|
call_args_list = mock_punct_model.predict.call_args_list
|
|
# After the first failed call, subsequent calls should have smaller batches
|
|
for i, call in enumerate(call_args_list[1:], start=1): # Skip first failed call
|
|
args, kwargs = call
|
|
batch = args[0]
|
|
assert len(batch) <= 100, f"After AssertionError, batch size should be <= 100, got {len(batch)}"
|
|
|
|
# Verify result contains all words
|
|
assert len(result) == 201, "Result should contain all words"
|
|
|