diff --git a/.coveragerc b/.coveragerc index 9f03fdf..7772715 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,4 +7,4 @@ omit = directory = coverage/html [report] -fail_under = 70 +fail_under = 74 diff --git a/buzz/gui.py b/buzz/gui.py index 7201ffa..7fde263 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -119,9 +119,9 @@ class LanguagesComboBox(QComboBox): self.currentIndexChanged.connect(self.on_index_changed) default_language_key = default_language if default_language != '' else None - default_language_index = next((i for i, lang in enumerate(self.languages) - if lang[0] == default_language_key), 0) - self.setCurrentIndex(default_language_index) + for i, lang in enumerate(self.languages): + if lang[0] == default_language_key: + self.setCurrentIndex(i) def on_index_changed(self, index: int): self.languageChanged.emit(self.languages[index][0]) @@ -361,6 +361,11 @@ class FileTranscriberWidget(QWidget): def on_word_level_timings_changed(self, value: int): self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value + def closeEvent(self, event: QtGui.QCloseEvent) -> None: + if self.transcriber_thread is not None: + self.transcriber_thread.wait() + super().closeEvent(event) + class TranscriptionViewerWidget(QWidget): transcription_task: FileTranscriptionTask diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 9350052..821e9fc 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -70,9 +70,8 @@ class ModelLoader(QObject): file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin') expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name] self.download_model(url, file_path, expected_sha256) - return - if self.model_type == ModelType.WHISPER: + elif self.model_type == ModelType.WHISPER: root_dir = os.getenv( "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper") @@ -82,9 +81,8 @@ class ModelLoader(QObject): file_path = os.path.join(root_dir, os.path.basename(url)) expected_sha256 = url.split('/')[-2] self.download_model(url, file_path, expected_sha256) - return - if self.model_type == ModelType.HUGGING_FACE: + else: # ModelType.HUGGING_FACE: self.progress.emit((0, 100)) try: @@ -95,8 +93,9 @@ class ModelLoader(QObject): return self.progress.emit((100, 100)) - self.finished.emit(self.hugging_face_model_id) - return + file_path = self.hugging_face_model_id + + self.finished.emit(file_path) def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]): try: @@ -108,14 +107,12 @@ class ModelLoader(QObject): if os.path.isfile(file_path): if expected_sha256 is None: - self.finished.emit(file_path) - return + return file_path model_bytes = open(file_path, "rb").read() model_sha256 = hashlib.sha256(model_bytes).hexdigest() if model_sha256 == expected_sha256: - self.finished.emit(file_path) - return + return file_path else: warnings.warn( f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file") @@ -141,7 +138,7 @@ class ModelLoader(QObject): "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the " "model.") - self.finished.emit(file_path) + return file_path except RuntimeError as exc: self.error.emit(str(exc)) logging.exception('') diff --git a/tests/gui_test.py b/tests/gui_test.py index e724f38..d822c78 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -1,7 +1,6 @@ import logging import os.path import pathlib -import platform from unittest.mock import Mock, patch import pytest @@ -12,8 +11,7 @@ from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication from pytestqt.qtbot import QtBot from buzz.cache import TasksCache -from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application, - AudioDevicesComboBox, DownloadModelProgressDialog, +from buzz.gui import (AboutDialog, AdvancedSettingsDialog, AudioDevicesComboBox, DownloadModelProgressDialog, FileTranscriberWidget, LanguagesComboBox, MainWindow, RecordingTranscriberWidget, TemperatureValidator, TextDisplayBox, @@ -25,30 +23,26 @@ from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, from tests.mock_sounddevice import MockInputStream -class TestApplication: - # FIXME: this seems to break the tests if not run?? - app = Application() - - def test_should_open_application(self): - assert self.app is not None - - class TestLanguagesComboBox: - languagesComboxBox = LanguagesComboBox('en') - def test_should_show_sorted_whisper_languages(self): - assert self.languagesComboxBox.itemText(0) == 'Detect Language' - assert self.languagesComboxBox.itemText(10) == 'Belarusian' - assert self.languagesComboxBox.itemText(20) == 'Dutch' - assert self.languagesComboxBox.itemText(30) == 'Gujarati' - assert self.languagesComboxBox.itemText(40) == 'Japanese' - assert self.languagesComboxBox.itemText(50) == 'Lithuanian' + def test_should_show_sorted_whisper_languages(self, qtbot): + languages_combox_box = LanguagesComboBox('en') + qtbot.add_widget(languages_combox_box) + assert languages_combox_box.itemText(0) == 'Detect Language' + assert languages_combox_box.itemText(10) == 'Belarusian' + assert languages_combox_box.itemText(20) == 'Dutch' + assert languages_combox_box.itemText(30) == 'Gujarati' + assert languages_combox_box.itemText(40) == 'Japanese' + assert languages_combox_box.itemText(50) == 'Lithuanian' - def test_should_select_en_as_default_language(self): - assert self.languagesComboxBox.currentText() == 'English' + def test_should_select_en_as_default_language(self, qtbot): + languages_combox_box = LanguagesComboBox('en') + qtbot.add_widget(languages_combox_box) + assert languages_combox_box.currentText() == 'English' - def test_should_select_detect_language_as_default(self): + def test_should_select_detect_language_as_default(self, qtbot): languages_combo_box = LanguagesComboBox(None) + qtbot.add_widget(languages_combo_box) assert languages_combo_box.currentText() == 'Detect Language' @@ -185,17 +179,16 @@ class TestMainWindow: class TestFileTranscriberWidget: - widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], parent=None) - def test_should_set_window_title(self, qtbot: QtBot): - qtbot.addWidget(self.widget) - assert self.widget.windowTitle() == 'whisper-french.mp3' + widget = FileTranscriberWidget( + file_paths=['testdata/whisper-french.mp3'], parent=None) + qtbot.add_widget(widget) + assert widget.windowTitle() == 'whisper-french.mp3' def test_should_emit_triggered_event(self, qtbot: QtBot): widget = FileTranscriberWidget( file_paths=['testdata/whisper-french.mp3'], parent=None) - qtbot.addWidget(widget) + qtbot.add_widget(widget) mock_triggered = Mock() widget.triggered.connect(mock_triggered) @@ -254,31 +247,41 @@ class TestTemperatureValidator: class TestTranscriptionViewerWidget: - widget = TranscriptionViewerWidget( - transcription_task=FileTranscriptionTask( - id=0, - file_path='testdata/whisper-french.mp3', - file_transcription_options=FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']), - transcription_options=TranscriptionOptions(), - segments=[Segment(40, 299, 'Bien'), - Segment(299, 329, 'venue dans')], - model_path='')) def test_should_display_segments(self, qtbot: QtBot): - qtbot.add_widget(self.widget) + widget = TranscriptionViewerWidget( + transcription_task=FileTranscriptionTask( + id=0, + file_path='testdata/whisper-french.mp3', + file_transcription_options=FileTranscriptionOptions( + file_paths=['testdata/whisper-french.mp3']), + transcription_options=TranscriptionOptions(), + segments=[Segment(40, 299, 'Bien'), + Segment(299, 329, 'venue dans')], + model_path='')) + qtbot.add_widget(widget) - assert self.widget.windowTitle() == 'whisper-french.mp3' + assert widget.windowTitle() == 'whisper-french.mp3' - text_display_box = self.widget.findChild(TextDisplayBox) + text_display_box = widget.findChild(TextDisplayBox) assert isinstance(text_display_box, TextDisplayBox) assert text_display_box.toPlainText( ) == '00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans' def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot): - qtbot.add_widget(self.widget) + widget = TranscriptionViewerWidget( + transcription_task=FileTranscriptionTask( + id=0, + file_path='testdata/whisper-french.mp3', + file_transcription_options=FileTranscriptionOptions( + file_paths=['testdata/whisper-french.mp3']), + transcription_options=TranscriptionOptions(), + segments=[Segment(40, 299, 'Bien'), + Segment(299, 329, 'venue dans')], + model_path='')) + qtbot.add_widget(widget) - export_button = self.widget.findChild(QPushButton) + export_button = widget.findChild(QPushButton) assert isinstance(export_button, QPushButton) output_file_path = tmp_path / 'whisper.txt' @@ -291,10 +294,10 @@ class TestTranscriptionViewerWidget: class TestTranscriptionTasksTableWidget: - widget = TranscriptionTasksTableWidget() def test_upsert_task(self, qtbot: QtBot): - qtbot.add_widget(self.widget) + widget = TranscriptionTasksTableWidget() + qtbot.add_widget(widget) task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3', transcription_options=TranscriptionOptions(), @@ -302,28 +305,28 @@ class TestTranscriptionTasksTableWidget: file_paths=['testdata/whisper-french.mp3']), model_path='', status=FileTranscriptionTask.Status.QUEUED) - self.widget.upsert_task(task) + widget.upsert_task(task) - assert self.widget.rowCount() == 1 - assert self.widget.item(0, 1).text() == 'whisper-french.mp3' - assert self.widget.item(0, 2).text() == 'Queued' + assert widget.rowCount() == 1 + assert widget.item(0, 1).text() == 'whisper-french.mp3' + assert widget.item(0, 2).text() == 'Queued' task.status = FileTranscriptionTask.Status.IN_PROGRESS task.fraction_completed = 0.3524 - self.widget.upsert_task(task) + widget.upsert_task(task) - assert self.widget.rowCount() == 1 - assert self.widget.item(0, 1).text() == 'whisper-french.mp3' - assert self.widget.item(0, 2).text() == 'In Progress (35%)' + assert widget.rowCount() == 1 + assert widget.item(0, 1).text() == 'whisper-french.mp3' + assert widget.item(0, 2).text() == 'In Progress (35%)' -@pytest.mark.skip() class TestRecordingTranscriberWidget: def test_should_set_window_title(self, qtbot: QtBot): widget = RecordingTranscriberWidget() qtbot.add_widget(widget) assert widget.windowTitle() == 'Live Recording' + @pytest.mark.skip() def test_should_transcribe(self, qtbot): widget = RecordingTranscriberWidget() qtbot.add_widget(widget)