buzz/tests/transcriber_test.py
2023-04-28 22:28:05 +01:00

227 lines
11 KiB
Python

import os
import pathlib
import platform
import tempfile
import time
from typing import List
from unittest.mock import Mock, patch
import pytest
from PyQt6.QtCore import QThread
from pytestqt.qtbot import QtBot
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber,
get_default_output_file_path, to_timestamp,
whisper_cpp_params, write_output, TranscriptionOptions)
from tests.mock_sounddevice import MockInputStream
from tests.model_loader import get_model_path
class TestRecordingTranscriber:
    """Tests for live-recording transcription via RecordingTranscriber."""

    @pytest.mark.skip(reason='Hanging')
    def test_should_transcribe(self, qtbot):
        """Record from a mocked input stream and check the French transcript text."""
        worker_thread = QThread()
        model = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
                                   whisper_model_size=WhisperModelSize.TINY)
        options = TranscriptionOptions(model=model, language='fr', task=Task.TRANSCRIBE)
        recorder = RecordingTranscriber(transcription_options=options,
                                        input_device_index=0, sample_rate=16_000)

        recorder.moveToThread(worker_thread)
        worker_thread.finished.connect(worker_thread.deleteLater)

        on_transcription = Mock()
        recorder.transcription.connect(on_transcription)
        # Tear the worker thread down once the recorder reports it is finished.
        recorder.finished.connect(worker_thread.quit)
        recorder.finished.connect(recorder.deleteLater)

        # Swap the real audio device for MockInputStream so no microphone is needed,
        # then start the thread and wait for the first transcription signal.
        with patch('sounddevice.InputStream', side_effect=MockInputStream):
            with patch('sounddevice.check_input_settings'):
                with qtbot.wait_signal(recorder.transcription, timeout=60 * 1000):
                    worker_thread.start()

        with qtbot.wait_signal(worker_thread.finished, timeout=60 * 1000):
            recorder.stop_recording()

        # The first positional argument of the last transcription signal is the text.
        transcript = on_transcription.call_args[0][0]
        assert 'Bienvenue dans Passe' in transcript
@pytest.mark.skipif(platform.system() == 'Windows', reason='whisper_cpp not printing segments on Windows')
class TestWhisperCppFileTranscriber:
    """Tests for file transcription with the whisper.cpp backend."""

    @pytest.mark.parametrize(
        'word_level_timings,expected_segments',
        [
            (False, [Segment(0, 6560,
                             'Bienvenue dans Passe-Relle. Un podcast pensé pour')]),
            (True, [Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
        ])
    def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
        """Transcribe the French sample file and compare the leading segments.

        Empty-text segments are dropped before comparison. Start/end times must
        match exactly; the expected text only needs to be contained in the
        produced segment's text.
        """
        # Resolve the sample relative to this file so the test does not depend on the CWD.
        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3'))
        file_transcription_options = FileTranscriptionOptions(file_paths=[file_path])
        transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
                                                     word_level_timings=word_level_timings,
                                                     model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
                                                                              whisper_model_size=WhisperModelSize.TINY))
        model_path = get_model_path(transcription_options.model)
        transcriber = WhisperCppFileTranscriber(
            task=FileTranscriptionTask(file_path=file_path,
                                       transcription_options=transcription_options,
                                       file_transcription_options=file_transcription_options, model_path=model_path))

        mock_progress = Mock()
        mock_completed = Mock()
        transcriber.progress.connect(mock_progress)
        transcriber.completed.connect(mock_completed)
        with qtbot.waitSignal(transcriber.completed, timeout=10 * 60 * 1000):
            transcriber.run()

        mock_progress.assert_called()
        segments = [segment for segment in mock_completed.call_args[0][0] if segment.text]
        # Fail with a clear assertion rather than an IndexError if too few segments are produced.
        assert len(segments) >= len(expected_segments)
        for expected_segment, segment in zip(expected_segments, segments):
            assert expected_segment.start == segment.start
            assert expected_segment.end == segment.end
            assert expected_segment.text in segment.text
class TestWhisperFileTranscriber:
    """Tests for file transcription with the Whisper, Hugging Face and Faster Whisper backends."""

    def test_default_output_file(self):
        """Default output paths keep the input's directory and stem and end with the format's extension."""
        txt_path = get_default_output_file_path(
            Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.TXT)
        assert txt_path.startswith('/a/b/c (Translated on ')
        assert txt_path.endswith('.txt')

        srt_path = get_default_output_file_path(
            Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.SRT)
        assert srt_path.startswith('/a/b/c (Translated on ')
        assert srt_path.endswith('.srt')

    @pytest.mark.parametrize(
        'word_level_timings,expected_segments,model,check_progress',
        [
            (False, [Segment(0, 6560,
                             ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')],
             TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
            (True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')],
             TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
            (False, [Segment(0, 8517,
                             ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances '
                             'et des apprenances de français.')],
             TranscriptionModel(model_type=ModelType.HUGGING_FACE,
                                hugging_face_model_id='openai/whisper-tiny'), False),
            pytest.param(
                False, [Segment(start=0, end=8400,
                                text=' Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.')],
                TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY), True,
                marks=pytest.mark.skipif(platform.system() == 'Darwin',
                                         reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087')
            )
        ])
    def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment],
                        model: TranscriptionModel, check_progress):
        """Transcribe the French sample file and compare the leading segments exactly.

        ``check_progress`` is currently unused: the progress assertions below are
        commented out, and the parameter is kept so they can be restored.
        """
        mock_progress = Mock()
        mock_completed = Mock()
        transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
                                                     word_level_timings=word_level_timings,
                                                     model=model)
        model_path = get_model_path(transcription_options.model)
        # Resolve the sample relative to this file so the test does not depend on the CWD.
        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3'))
        file_transcription_options = FileTranscriptionOptions(file_paths=[file_path])

        transcriber = WhisperFileTranscriber(
            task=FileTranscriptionTask(transcription_options=transcription_options,
                                       file_transcription_options=file_transcription_options,
                                       file_path=file_path, model_path=model_path))
        transcriber.progress.connect(mock_progress)
        transcriber.completed.connect(mock_completed)
        # NOTE(review): 10 * 6000 ms is one minute, whereas the whisper.cpp test allows
        # 10 * 60 * 1000 ms; confirm which budget is intended.
        with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(
                transcriber.completed, timeout=10 * 6000):
            transcriber.run()

        # Skip checking progress...
        # if check_progress:
        #     # Reports progress at 0, 0<progress<100, and 100
        #     assert any(
        #         [call_args.args[0] == (0, 100) for call_args in mock_progress.call_args_list])
        #     assert any(
        #         [call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
        #     assert any(
        #         [(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
        #          mock_progress.call_args_list])

        mock_completed.assert_called()
        segments = mock_completed.call_args[0][0]
        assert len(segments) >= len(expected_segments)
        for expected_segment, segment in zip(expected_segments, segments):
            assert segment == expected_segment

    @pytest.mark.skip()
    def test_transcribe_stop(self):
        """Stopping a transcription should leave no output file in the temp directory."""
        output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt')
        if os.path.exists(output_file_path):
            os.remove(output_file_path)

        # Resolve the sample relative to this file, as in test_transcribe.
        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3'))
        file_transcription_options = FileTranscriptionOptions(file_paths=[file_path])
        transcription_options = TranscriptionOptions(
            language='fr', task=Task.TRANSCRIBE, word_level_timings=False,
            model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
        model_path = get_model_path(transcription_options.model)
        transcriber = WhisperFileTranscriber(
            task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options,
                                       file_transcription_options=file_transcription_options,
                                       file_path=file_path))
        # NOTE(review): run() is called directly on this thread, so stop() presumably only
        # executes after run() returns; confirm whether run() was meant to be on another thread.
        transcriber.run()
        time.sleep(1)
        transcriber.stop()

        # Assert that file was not created
        assert os.path.isfile(output_file_path) is False
class TestToTimestamp:
    """Checks the timestamp strings produced by to_timestamp."""

    def test_to_timestamp(self):
        """Millisecond counts map to zero-padded HH:MM:SS.mmm strings, with hours allowed past 24."""
        cases = (
            (0, '00:00:00.000'),
            (123456789, '34:17:36.789'),
        )
        for milliseconds, expected in cases:
            assert to_timestamp(milliseconds) == expected
class TestWhisperCpp:
    """Tests for the low-level WhisperCpp wrapper."""

    def test_transcribe(self):
        """Transcribe the French sample with the tiny whisper.cpp model and check the text."""
        transcription_options = TranscriptionOptions(
            model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
        model_path = get_model_path(transcription_options.model)

        whisper_cpp = WhisperCpp(model=model_path)
        params = whisper_cpp_params(
            language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
        # Resolve the sample relative to this file so the test does not depend on the CWD.
        audio_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3'))
        result = whisper_cpp.transcribe(audio=audio_path, params=params)

        assert 'Bienvenue dans Passe' in result['text']
@pytest.mark.parametrize(
    'output_format,output_text',
    [
        (OutputFormat.TXT, 'Bien venue dans\n'),
        (
            OutputFormat.SRT,
            '1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n'),
        (OutputFormat.VTT,
         'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
    ])
def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str):
    """write_output renders the same two segments as TXT, SRT and VTT with the exact expected text."""
    output_file_path = tmp_path / 'whisper.txt'
    segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]
    write_output(path=str(output_file_path), segments=segments, output_format=output_format)

    # Path.read_text opens and closes the file, so no file handle is left open.
    assert output_text == output_file_path.read_text(encoding='utf-8')