From 7395de653f790ff8b0cd527ddb219bcd4e589b83 Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Fri, 16 Dec 2022 01:23:29 +0000 Subject: [PATCH] Add transcription viewer (#246) --- .coveragerc | 2 +- buzz/gui.py | 157 ++++++++++++++++++++++++++------------ buzz/transcriber.py | 127 +++++++++++++++++------------- tests/gui_test.py | 96 ++++++++++++++++------- tests/transcriber_test.py | 112 ++++++++++++++++----------- 5 files changed, 324 insertions(+), 170 deletions(-) diff --git a/.coveragerc b/.coveragerc index ea0d5b4..f161d58 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,4 +7,4 @@ omit = directory = coverage/html [report] -fail_under = 77 +fail_under = 76 diff --git a/buzz/gui.py b/buzz/gui.py index eee1650..009d7af 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import enum import logging import os @@ -9,15 +10,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union import humanize import sounddevice from PyQt6 import QtGui -from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, +from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, pyqtSlot, QThreadPool, QTimer, QUrl, pyqtSignal) from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon, QKeySequence, QPixmap, QTextCursor, QValidator) from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog, - QDialogButtonBox, QFileDialog, QGridLayout, - QHBoxLayout, QLabel, QLayout, QLineEdit, + QDialogButtonBox, QFileDialog, QGridLayout, QToolButton, + QLabel, QLayout, QLineEdit, QMainWindow, QMessageBox, QPlainTextEdit, - QProgressDialog, QPushButton, QVBoxLayout, + QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu, QWidget) from requests import get from whisper import tokenizer @@ -25,10 +26,10 @@ from whisper import tokenizer from .__version__ import VERSION from .model_loader import ModelLoader from .transcriber import (DEFAULT_WHISPER_TEMPERATURE, LOADED_WHISPER_DLL, - SUPPORTED_OUTPUT_FORMATS, OutputFormat, - RecordingTranscriber, Task, + SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat, + RecordingTranscriber, Segment, Task, WhisperCppFileTranscriber, WhisperFileTranscriber, - get_default_output_file_path) + get_default_output_file_path, segments_to_text, write_output) APP_NAME = 'Buzz' @@ -174,19 +175,9 @@ class QualityComboBox(QComboBox): class TextDisplayBox(QPlainTextEdit): """TextDisplayBox is a read-only textbox""" - os_styles = { - 'Darwin': '''QTextEdit { - border: 0; - }''' - } - def __init__(self, parent: Optional[QWidget], *args) -> None: super().__init__(parent, *args) self.setReadOnly(True) - self.setPlaceholderText('Click Record to begin...') - self.setStyleSheet( - '''QTextEdit { - } %s''' % get_platform_styles(self.os_styles)) class RecordButton(QPushButton): @@ -371,6 +362,8 @@ class FileTranscriberWidget(QWidget): model_loader: Optional[ModelLoader] = None transcriber_thread: Optional[QThread] = None transcribed = pyqtSignal() + transcription_options: FileTranscriptionOptions + is_transcribing = False def __init__(self, file_path: str, parent: Optional[QWidget]) -> None: super().__init__(parent) @@ -378,6 +371,8 @@ class FileTranscriberWidget(QWidget): layout = QGridLayout(self) self.settings = Settings(self) + self.transcription_options = FileTranscriptionOptions( + file_path=file_path) self.file_path = file_path @@ -462,15 +457,6 @@ class FileTranscriberWidget(QWidget): self.initial_prompt = initial_prompt def on_click_run(self): - default_path = get_default_output_file_path( - task=self.selected_task, input_file_path=self.file_path, - output_format=self.selected_output_format) - (output_file, _) = QFileDialog.getSaveFileName( - self, 'Save File', default_path, f'Text files (*.{self.selected_output_format.value})') - - if output_file == '': - return - use_whisper_cpp = self.settings.get_enable_ggml_inference( ) and self.selected_language is not None @@ -482,20 +468,18 @@ class FileTranscriberWidget(QWidget): self.model_loader = ModelLoader( name=model_name, use_whisper_cpp=use_whisper_cpp) + self.transcription_options = FileTranscriptionOptions( + file_path=self.file_path, language=self.selected_language, + task=self.selected_task, word_level_timings=self.enabled_word_level_timings, + temperature=self.temperature, initial_prompt=self.initial_prompt + ) + if use_whisper_cpp: self.file_transcriber = WhisperCppFileTranscriber( - file_path=self.file_path, language=self.selected_language, task=self.selected_task, - output_file_path=output_file, output_format=self.selected_output_format, - word_level_timings=self.enabled_word_level_timings, - ) + self.transcription_options) else: self.file_transcriber = WhisperFileTranscriber( - file_path=self.file_path, language=self.selected_language, task=self.selected_task, - output_file_path=output_file, output_format=self.selected_output_format, - word_level_timings=self.enabled_word_level_timings, - temperature=self.temperature, - initial_prompt=self.initial_prompt - ) + self.transcription_options) self.model_loader.moveToThread(self.transcriber_thread) self.file_transcriber.moveToThread(self.transcriber_thread) @@ -515,6 +499,7 @@ class FileTranscriberWidget(QWidget): self.model_loader.finished.connect(self.model_loader.deleteLater) # Run the file transcriber after the model loads + self.model_loader.finished.connect(self.on_model_loaded) self.model_loader.finished.connect(self.file_transcriber.run) self.file_transcriber.progress.connect( @@ -527,6 +512,9 @@ class FileTranscriberWidget(QWidget): self.transcriber_thread.start() + def on_model_loaded(self): + self.is_transcribing = True + def on_download_model_progress(self, progress: Tuple[int, int]): (current_size, total_size) = progress @@ -545,22 +533,26 @@ class FileTranscriberWidget(QWidget): self.reset_transcriber_controls() def on_transcriber_progress(self, progress: Tuple[int, int]): - logging.debug('received progress = %s', progress) (current_size, total_size) = progress - # Create a dialog - if self.transcriber_progress_dialog is None: - self.transcriber_progress_dialog = TranscriberProgressDialog( - file_path=self.file_path, total_size=total_size, parent=self) - self.transcriber_progress_dialog.canceled.connect( - self.on_cancel_transcriber_progress_dialog) + if self.is_transcribing: + # Create a dialog + if self.transcriber_progress_dialog is None: + self.transcriber_progress_dialog = TranscriberProgressDialog( + file_path=self.file_path, total_size=total_size, parent=self) + self.transcriber_progress_dialog.canceled.connect( + self.on_cancel_transcriber_progress_dialog) + else: + # Update the progress of the dialog unless it has + # been canceled before this progress update arrived + self.transcriber_progress_dialog.update_progress(current_size) - # Update the progress of the dialog unless it has - # been canceled before this progress update arrived - if self.transcriber_progress_dialog is not None: - self.transcriber_progress_dialog.update_progress(current_size) + @pyqtSlot(tuple) + def on_transcriber_complete(self, result: Tuple[int, List[Segment]]): + exit_code, segments = result + + self.is_transcribing = False - def on_transcriber_complete(self, exit_code: int): if self.transcriber_progress_dialog is not None: self.transcriber_progress_dialog.reset() if exit_code != 0: @@ -569,6 +561,10 @@ class FileTranscriberWidget(QWidget): self.reset_transcriber_controls() self.transcribed.emit() + TranscriptionViewerWidget( + transcription_options=self.transcription_options, + segments=segments, parent=self, flags=Qt.WindowType.Window).show() + def on_cancel_transcriber_progress_dialog(self): if self.file_transcriber is not None: self.file_transcriber.stop() @@ -590,6 +586,70 @@ class FileTranscriberWidget(QWidget): self.enabled_word_level_timings = value == Qt.CheckState.Checked.value +class TranscriptionViewerWidget(QWidget): + segments: List[Segment] + transcription_options: FileTranscriptionOptions + + def __init__( + self, transcription_options: FileTranscriptionOptions, segments: List[Segment], + parent: Optional['QWidget'] = None, flags: Qt.WindowType = Qt.WindowType.Widget, + ) -> None: + super().__init__(parent, flags) + self.segments = segments + self.transcription_options = transcription_options + + self.setMinimumWidth(500) + self.setMinimumHeight(500) + + self.setWindowTitle( + f'Transcription - {get_short_file_path(transcription_options.file_path)}') + + layout = QVBoxLayout(self) + + text = segments_to_text(segments) + + self.text_box = TextDisplayBox(self) + self.text_box.setPlainText(text) + + layout.addWidget(self.text_box) + + buttons_layout = QHBoxLayout() + buttons_layout.addStretch() + + menu = QMenu() + actions = [QAction(text=output_format.value.upper(), parent=self) + for output_format in OutputFormat] + menu.addActions(actions) + + menu.triggered.connect(self.on_menu_triggered) + + export_button = QPushButton(self) + export_button.setText('Export') + export_button.setMenu(menu) + + buttons_layout.addWidget(export_button) + layout.addLayout(buttons_layout) + + self.setLayout(layout) + + def on_menu_triggered(self, action: QAction): + output_format = OutputFormat[action.text()] + + default_path = get_default_output_file_path( + task=self.transcription_options.task, + input_file_path=self.transcription_options.file_path, + output_format=output_format) + + (output_file_path, _) = QFileDialog.getSaveFileName( + self, 'Save File', default_path, f'Text files (*.{output_format.value})') + + if output_file_path == '': + return + + write_output(path=output_file_path, segments=self.segments, + should_open=True, output_format=output_format) + + class Settings(QSettings): _ENABLE_GGML_INFERENCE = 'enable_ggml_inference' @@ -686,6 +746,7 @@ class RecordingTranscriberWidget(QWidget): self.open_advanced_settings) self.text_box = TextDisplayBox(self) + self.text_box.setPlaceholderText('Click Record to begin...') widgets = [ ((0, 5, FormLabel('Task:', self)), (5, 7, self.tasks_combo_box)), diff --git a/buzz/transcriber.py b/buzz/transcriber.py index f80fd71..ebe9c62 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -1,10 +1,12 @@ import ctypes import datetime import enum +import json import logging import multiprocessing import os import platform +import queue import re import subprocess import sys @@ -195,28 +197,33 @@ class OutputFormat(enum.Enum): VTT = 'vtt' +@dataclass +class FileTranscriptionOptions: + file_path: str + language: Optional[str] = None + task: Task = Task.TRANSCRIBE + word_level_timings: bool = False + temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE + initial_prompt: str = '' + + class WhisperCppFileTranscriber(QObject): progress = pyqtSignal(tuple) # (current, total) - completed = pyqtSignal(int) + completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment]) error = pyqtSignal(str) duration_audio_ms = sys.maxsize # max int segments: List[Segment] running = False def __init__( - self, language: Optional[str], task: Task, file_path: str, - output_file_path: str, output_format: OutputFormat, - word_level_timings: bool, open_file_on_complete=True, + self, transcription_options: FileTranscriptionOptions, parent: Optional['QObject'] = None) -> None: super().__init__(parent) - self.file_path = file_path - self.output_file_path = output_file_path - self.language = language - self.task = task - self.open_file_on_complete = open_file_on_complete - self.output_format = output_format - self.word_level_timings = word_level_timings + self.file_path = transcription_options.file_path + self.language = transcription_options.language + self.task = transcription_options.task + self.word_level_timings = transcription_options.word_level_timings self.segments = [] self.process = QProcess(self) @@ -228,8 +235,8 @@ class WhisperCppFileTranscriber(QObject): self.running = True logging.debug( - 'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s', - self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path) + 'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s', + self.file_path, self.language, self.task, model_path) wav_file = tempfile.mktemp()+'.wav' ( @@ -258,10 +265,8 @@ class WhisperCppFileTranscriber(QObject): if status == QProcess.ExitStatus.NormalExit: self.progress.emit( (self.duration_audio_ms, self.duration_audio_ms)) - write_output( - self.output_file_path, self.segments, self.open_file_on_complete, self.output_format) - self.completed.emit(status == self.process.exitCode()) + self.completed.emit((self.process.exitCode(), self.segments)) self.running = False def stop(self): @@ -282,7 +287,7 @@ class WhisperCppFileTranscriber(QObject): segment = Segment(start, end, text.strip()) self.segments.append(segment) self.progress.emit((end, self.duration_audio_ms)) - except UnicodeDecodeError: + except (UnicodeDecodeError, ValueError): pass def parse_timings(self, timings: str) -> Tuple[int, int]: @@ -315,44 +320,39 @@ class WhisperFileTranscriber(QObject): current_process: multiprocessing.Process progress = pyqtSignal(tuple) # (current, total) - completed = pyqtSignal(int) + completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment]) error = pyqtSignal(str) running = False read_line_thread: Optional[Thread] = None - def __init__(self, - language: Optional[str], task: Task, file_path: str, - output_file_path: str, output_format: OutputFormat, - word_level_timings: bool, temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE, - initial_prompt: str = '', open_file_on_complete=True, parent: Optional['QObject'] = None) -> None: + def __init__( + self, transcription_options: FileTranscriptionOptions, + parent: Optional['QObject'] = None) -> None: super().__init__(parent) - self.file_path = file_path - self.output_file_path = output_file_path - self.language = language - self.task = task - self.open_file_on_complete = open_file_on_complete - self.output_format = output_format - self.word_level_timings = word_level_timings - self.temperature = temperature - self.initial_prompt = initial_prompt + self.file_path = transcription_options.file_path + self.language = transcription_options.language + self.task = transcription_options.task + self.word_level_timings = transcription_options.word_level_timings + self.temperature = transcription_options.temperature + self.initial_prompt = transcription_options.initial_prompt + self.segments = [] @pyqtSlot(str) def run(self, model_path: str): self.running = True time_started = datetime.datetime.now() logging.debug( - 'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model path = %s, temperature = %s, initial prompt length = %s', - self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path, self.temperature, len(self.initial_prompt)) + 'Starting whisper file transcription, file path = %s, language = %s, task = %s, model path = %s, temperature = %s, initial prompt length = %s', + self.file_path, self.language, self.task, model_path, self.temperature, len(self.initial_prompt)) recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False) + self.current_process = multiprocessing.Process( target=transcribe_whisper, args=( send_pipe, model_path, self.file_path, - self.language, self.task, self.output_file_path, - self.open_file_on_complete, self.output_format, - self.word_level_timings, + self.language, self.task, self.word_level_timings, self.temperature, self.initial_prompt )) @@ -373,7 +373,9 @@ class WhisperFileTranscriber(QObject): self.read_line_thread.join() - self.completed.emit(self.current_process.exitcode) + if self.current_process.exitcode != 0: + self.completed.emit((self.current_process.exitcode, [])) + self.running = False def stop(self): @@ -381,6 +383,17 @@ class WhisperFileTranscriber(QObject): self.current_process.terminate() def on_whisper_stdout(self, line: str): + if line.startswith('segments = '): + segments_dict = json.loads(line[11:]) + segments = [Segment( + start=segment.get('start'), + end=segment.get('end'), + text=segment.get('text'), + ) for segment in segments_dict] + self.current_process.join() + self.completed.emit((self.current_process.exitcode, segments)) + return + try: progress = int(line.split('|')[0].strip().strip('%')) self.progress.emit((progress, 100)) @@ -398,8 +411,7 @@ class WhisperFileTranscriber(QObject): def transcribe_whisper( stderr_conn: Connection, model_path: str, file_path: str, - language: Optional[str], task: Task, output_file_path: str, - open_file_on_complete: bool, output_format: OutputFormat, + language: Optional[str], task: Task, word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str): with pipe_stderr(stderr_conn): model = whisper.load_model(model_path) @@ -418,15 +430,15 @@ def transcribe_whisper( whisper_segments = stable_whisper.group_word_timestamps( result) if word_level_timings else result.get('segments') - segments = map( - lambda segment: Segment( - start=segment.get('start')*1000, # s to ms - end=segment.get('end')*1000, # s to ms - text=segment.get('text')), - whisper_segments) - - write_output(output_file_path, list( - segments), open_file_on_complete, output_format) + segments = [ + Segment( + start=int(segment.get('start')*1000), + end=int(segment.get('end')*1000), + text=segment.get('text'), + ) for segment in whisper_segments] + segments_json = json.dumps( + segments, ensure_ascii=True, default=vars) + sys.stderr.write(f'segments = {segments_json}\n') def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat): @@ -435,8 +447,11 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f with open(path, 'w', encoding='utf-8') as file: if output_format == OutputFormat.TXT: - for segment in segments: - file.write(segment.text + ' ') + for (i, segment) in enumerate(segments): + file.write(segment.text) + if i < len(segments)-1: + file.write(' ') + file.write('\n') elif output_format == OutputFormat.VTT: file.write('WEBVTT\n\n') @@ -460,6 +475,16 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f subprocess.call([opener, path]) +def segments_to_text(segments: List[Segment]) -> str: + result = '' + for (i, segment) in enumerate(segments): + result += f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n' + result += f'{segment.text}' + if i < len(segments)-1: + result += '\n\n' + return result + + def to_timestamp(ms: float) -> str: hr = int(ms / (1000*60*60)) ms = ms - hr * (1000*60*60) diff --git a/tests/gui_test.py b/tests/gui_test.py index 9b15be9..c45d049 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -1,21 +1,24 @@ import logging import os import pathlib +from typing import Any, Callable +import platform from unittest.mock import Mock, patch import pytest import sounddevice from PyQt6.QtCore import Qt, QCoreApplication from PyQt6.QtGui import (QValidator) +from PyQt6.QtWidgets import (QPushButton) from pytestqt.qtbot import QtBot from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application, AudioDevicesComboBox, DownloadModelProgressDialog, FileTranscriberWidget, LanguagesComboBox, MainWindow, OutputFormatsComboBox, Quality, QualityComboBox, - Settings, TemperatureValidator, - TranscriberProgressDialog) -from buzz.transcriber import OutputFormat + Settings, TemperatureValidator, TextDisplayBox, + TranscriberProgressDialog, TranscriptionViewerWidget) +from buzz.transcriber import FileTranscriptionOptions, OutputFormat, Segment class TestApplication: @@ -142,25 +145,29 @@ class TestTranscriberProgressDialog: class TestDownloadModelProgressDialog: - dialog = DownloadModelProgressDialog(total_size=1234567, parent=None) + def test_should_show_dialog(self, qtbot: QtBot): + dialog = DownloadModelProgressDialog(total_size=1234567, parent=None) + qtbot.add_widget(dialog) + assert dialog.labelText() == 'Downloading resources (0%, unknown time remaining)' - def test_should_show_dialog(self): - assert self.dialog.labelText() == 'Downloading resources (0%, unknown time remaining)' + def test_should_update_label_on_progress(self, qtbot: QtBot): + dialog = DownloadModelProgressDialog(total_size=1234567, parent=None) + qtbot.add_widget(dialog) + dialog.setValue(0) - def test_should_update_label_on_progress(self): - self.dialog.setValue(0) - - self.dialog.setValue(12345) - assert self.dialog.labelText().startswith( + dialog.setValue(12345) + assert dialog.labelText().startswith( 'Downloading resources (1.00%') - self.dialog.setValue(123456) - assert self.dialog.labelText().startswith( + dialog.setValue(123456) + assert dialog.labelText().startswith( 'Downloading resources (10.00%') # Other windows should not be processing while models are being downloaded - def test_should_be_an_application_modal(self): - assert self.dialog.windowModality() == Qt.WindowModality.ApplicationModal + def test_should_be_an_application_modal(self, qtbot: QtBot): + dialog = DownloadModelProgressDialog(total_size=1234567, parent=None) + qtbot.add_widget(dialog) + assert dialog.windowModality() == Qt.WindowModality.ApplicationModal class TestFormatsComboBox: @@ -177,21 +184,32 @@ class TestMainWindow: assert main_window is not None +def wait_until(callback: Callable[[], Any], timeout=0): + while True: + try: + QCoreApplication.processEvents() + callback() + return + except AssertionError: + pass + + class TestFileTranscriberWidget: - def test_should_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path): + @pytest.mark.skipif(condition=platform.system() == 'Windows', reason='Waiting for signal crashes process on Windows') + def test_should_transcribe(self, qtbot: QtBot): widget = FileTranscriberWidget( file_path='testdata/whisper-french.mp3', parent=None) qtbot.addWidget(widget) - output_file_path = tmp_path / 'whisper.txt' + # Waiting for a "transcribed" signal seems to work more consistently + # than checking for the opening of a TranscriptionViewerWidget. + # See also: https://github.com/pytest-dev/pytest-qt/issues/313 + with qtbot.wait_signal(widget.transcribed, timeout=30*1000): + qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton) - with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock, - qtbot.wait_signal(widget.transcribed, timeout=10*1000)): - save_file_name_mock.return_value = (output_file_path, '') - widget.run_button.click() - - output_file = open(output_file_path, 'r', encoding='utf-8') - assert 'Bienvenue dans Passe' in output_file.read() + transcription_viewer = widget.findChild(TranscriptionViewerWidget) + assert isinstance(transcription_viewer, TranscriptionViewerWidget) + assert len(transcription_viewer.segments) > 0 @pytest.mark.skip(reason="transcription_started callback sometimes not getting called until all progress events are emitted") def test_should_transcribe_and_stop(self, qtbot: QtBot, tmp_path: pathlib.Path): @@ -202,13 +220,12 @@ class TestFileTranscriberWidget: output_file_path = tmp_path / 'whisper.txt' with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock): - save_file_name_mock.return_value = (output_file_path, '') + save_file_name_mock.return_value = (str(output_file_path), '') widget.run_button.click() def transcription_started(): QCoreApplication.processEvents() assert widget.transcriber_progress_dialog is not None - logging.debug('asserted value = %s', widget.transcriber_progress_dialog.value()) assert widget.transcriber_progress_dialog.value() > 0 qtbot.wait_until(transcription_started, timeout=30*1000) @@ -270,3 +287,30 @@ class TestTemperatureValidator: ]) def test_should_validate_temperature(self, text: str, state: QValidator.State): assert self.validator.validate(text, 0)[0] == state + + +class TestTranscriptionViewerWidget: + widget = TranscriptionViewerWidget( + transcription_options=FileTranscriptionOptions( + file_path='testdata/whisper-french.mp3'), + segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]) + + def test_should_display_segments(self): + assert self.widget.windowTitle() == 'Transcription - whisper-french.mp3' + + text_display_box = self.widget.findChild(TextDisplayBox) + assert isinstance(text_display_box, TextDisplayBox) + assert text_display_box.toPlainText( + ) == '00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans' + + def test_should_export_segments(self, tmp_path: pathlib.Path): + export_button = self.widget.findChild(QPushButton) + assert isinstance(export_button, QPushButton) + + output_file_path = tmp_path / 'whisper.txt' + with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock: + save_file_name_mock.return_value = (str(output_file_path), '') + export_button.menu().actions()[0].trigger() + + output_file = open(output_file_path, 'r', encoding='utf-8') + assert 'Bien venue dans' in output_file.read() diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index 5e3367f..9454cca 100644 --- a/tests/transcriber_test.py +++ b/tests/transcriber_test.py @@ -1,7 +1,9 @@ +import logging import os import pathlib import tempfile import time +from typing import List from unittest.mock import Mock from PyQt6.QtCore import QCoreApplication @@ -9,11 +11,11 @@ import pytest from pytestqt.qtbot import QtBot from buzz.model_loader import ModelLoader -from buzz.transcriber import (OutputFormat, RecordingTranscriber, Task, +from buzz.transcriber import (FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Segment, Task, WhisperCpp, WhisperCppFileTranscriber, WhisperFileTranscriber, get_default_output_file_path, to_timestamp, - whisper_cpp_params) + whisper_cpp_params, write_output) def get_model_path(model_name: str, use_whisper_cpp: bool) -> str: @@ -40,33 +42,31 @@ class TestRecordingTranscriber: class TestWhisperCppFileTranscriber: @pytest.mark.parametrize( - 'task,output_text', + 'word_level_timings,expected_segments', [ - (Task.TRANSCRIBE, 'Bienvenue dans Passe'), - (Task.TRANSLATE, 'Welcome to Passe-Relle'), + (False, [Segment(0, 1840, 'Bienvenue dans Passe Relle.')]), + (True, [Segment(30, 280, 'Bien'), Segment(280, 630, 'venue')]) ]) - def test_transcribe(self, qtbot, tmp_path: pathlib.Path, task: Task, output_text: str): - output_file_path = tmp_path / 'whisper_cpp.txt' - if os.path.exists(output_file_path): - os.remove(output_file_path) + def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]): + transcription_options = FileTranscriptionOptions( + language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3', + word_level_timings=word_level_timings) model_path = get_model_path('tiny', True) transcriber = WhisperCppFileTranscriber( - language='fr', task=task, file_path='testdata/whisper-french.mp3', - output_file_path=output_file_path.as_posix(), output_format=OutputFormat.TXT, - open_file_on_complete=False, - word_level_timings=False) + transcription_options=transcription_options) mock_progress = Mock() - with qtbot.waitSignal(transcriber.completed, timeout=10*60*1000): - transcriber.progress.connect(mock_progress) + mock_completed = Mock() + transcriber.progress.connect(mock_progress) + transcriber.completed.connect(mock_completed) + with qtbot.waitSignal(transcriber.completed, timeout=10 * 60 * 1000): transcriber.run(model_path) - assert os.path.isfile(output_file_path) - - output_file = open(output_file_path, 'r', encoding='utf-8') - assert output_text in output_file.read() - mock_progress.assert_called() + exit_code, segments = mock_completed.call_args[0][0] + assert exit_code is 0 + for expected_segment in expected_segments: + assert expected_segment in segments class TestWhisperFileTranscriber: @@ -82,36 +82,31 @@ class TestWhisperFileTranscriber: assert srt.endswith('.srt') @pytest.mark.parametrize( - 'word_level_timings,output_format,output_text', + 'word_level_timings,expected_segments', [ - (False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle'), - (False, OutputFormat.SRT, - '1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle'), - (False, OutputFormat.VTT, - 'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe'), - (True, OutputFormat.SRT, - '1\n00:00:00.040 --> 00:00:00.299\n Bien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n3\n00:00:00.329 --> 00:00:00.429\n P\n\n4\n00:00:00.429 --> 00:00:00.589\nasse-'), + (False, [ + Segment( + 0, 6560, + ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'), + ]), + (True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')]) ]) - def test_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str): - output_file_path = tmp_path / f'whisper.{output_format.value.lower()}' - + def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]): model_path = get_model_path('tiny', False) mock_progress = Mock() - transcriber = WhisperFileTranscriber( + mock_completed = Mock() + transcription_options = FileTranscriptionOptions( language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3', - output_file_path=output_file_path.as_posix(), output_format=output_format, - open_file_on_complete=False, word_level_timings=word_level_timings) + + transcriber = WhisperFileTranscriber( + transcription_options=transcription_options) transcriber.progress.connect(mock_progress) - with qtbot.wait_signal(transcriber.completed, timeout=10*6000): + transcriber.completed.connect(mock_completed) + with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000): transcriber.run(model_path) - assert os.path.isfile(output_file_path) - - output_file = open(output_file_path, 'r', encoding='utf-8') - assert output_text in output_file.read() - QCoreApplication.processEvents() # Reports progress at 0, 0 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'), + (OutputFormat.VTT, + 'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'), + ]) +def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str): + output_file_path = tmp_path / 'whisper.txt' + segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')] + + write_output(path=str(output_file_path), segments=segments, + should_open=False, output_format=output_format) + + output_file = open(output_file_path, 'r', encoding='utf-8') + assert output_text == output_file.read()