buzz/tests/transcriber/whisper_cpp_file_transcriber_test.py

77 lines
2.6 KiB
Python

from typing import List
from unittest.mock import Mock
import pytest
from pytestqt.qtbot import QtBot
from buzz.model_loader import TranscriptionModel, ModelType, WhisperModelSize
from buzz.transcriber.transcriber import (
Segment,
FileTranscriptionOptions,
TranscriptionOptions,
Task,
FileTranscriptionTask,
)
from buzz.transcriber.whisper_cpp_file_transcriber import WhisperCppFileTranscriber
from tests.audio import test_audio_path
from tests.model_loader import get_model_path
class TestWhisperCppFileTranscriber:
    """Tests for ``WhisperCppFileTranscriber`` using the shared test audio file."""

    @pytest.mark.parametrize(
        "word_level_timings,expected_segments",
        [
            (
                False,
                [Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")],
            ),
            (True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]),
        ],
    )
    def test_transcribe(
        self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]
    ):
        """Run a transcription and compare its leading segments to the expected ones.

        The transcriber is configured with the ``WHISPER_CPP`` model type at
        ``TINY`` size, language ``"fr"`` and the ``TRANSCRIBE`` task. Only the
        first ``len(expected_segments)`` non-empty segments are compared:
        ``start`` and ``end`` must match exactly, while the expected text only
        has to be contained in the produced text.

        Args:
            qtbot: pytest-qt fixture used to wait for the ``completed`` signal.
            word_level_timings: Value passed to
                ``TranscriptionOptions.word_level_timings``.
            expected_segments: Segments the non-empty output must start with.
        """
        file_transcription_options = FileTranscriptionOptions(
            file_paths=[test_audio_path]
        )
        transcription_options = TranscriptionOptions(
            language="fr",
            task=Task.TRANSCRIBE,
            word_level_timings=word_level_timings,
            model=TranscriptionModel(
                model_type=ModelType.WHISPER_CPP,
                whisper_model_size=WhisperModelSize.TINY,
            ),
        )
        model_path = get_model_path(transcription_options.model)

        transcriber = WhisperCppFileTranscriber(
            task=FileTranscriptionTask(
                file_path=test_audio_path,
                transcription_options=transcription_options,
                file_transcription_options=file_transcription_options,
                model_path=model_path,
            )
        )

        # Record every signal emission so it can be inspected after the run;
        # the progress mock also prints each reported value.
        mock_progress = Mock(side_effect=lambda value: print("progress: ", value))
        mock_completed = Mock()
        mock_error = Mock()
        transcriber.progress.connect(mock_progress)
        transcriber.completed.connect(mock_completed)
        transcriber.error.connect(mock_error)

        # The timeout is expressed in milliseconds (10 minutes).
        with qtbot.wait_signal(transcriber.completed, timeout=10 * 60 * 1000):
            transcriber.run()

        mock_error.assert_not_called()
        mock_progress.assert_called()

        # The first positional argument of the last `completed` emission is the
        # sequence of segments; segments with empty text are ignored.
        segments = [
            segment for segment in mock_completed.call_args[0][0] if segment.text
        ]

        # Fail with a readable message instead of an IndexError when the
        # transcription produced fewer non-empty segments than expected.
        assert len(segments) >= len(expected_segments), (
            f"expected at least {len(expected_segments)} non-empty segments, "
            f"got {len(segments)}"
        )

        for expected_segment, segment in zip(expected_segments, segments):
            assert expected_segment.start == segment.start
            assert expected_segment.end == segment.end
            assert expected_segment.text in segment.text