Fix intermittently failing tests (#277)

This commit is contained in:
Chidi Williams 2022-12-31 09:13:52 +00:00 committed by GitHub
parent cfcb4f6c28
commit f6ef2d5fe3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 69 deletions

View file

@ -7,4 +7,4 @@ omit =
directory = coverage/html
[report]
fail_under = 70
fail_under = 74

View file

@ -119,9 +119,9 @@ class LanguagesComboBox(QComboBox):
self.currentIndexChanged.connect(self.on_index_changed)
default_language_key = default_language if default_language != '' else None
default_language_index = next((i for i, lang in enumerate(self.languages)
if lang[0] == default_language_key), 0)
self.setCurrentIndex(default_language_index)
for i, lang in enumerate(self.languages):
if lang[0] == default_language_key:
self.setCurrentIndex(i)
def on_index_changed(self, index: int):
self.languageChanged.emit(self.languages[index][0])
@ -361,6 +361,11 @@ class FileTranscriberWidget(QWidget):
def on_word_level_timings_changed(self, value: int):
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
if self.transcriber_thread is not None:
self.transcriber_thread.wait()
super().closeEvent(event)
class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask

View file

@ -70,9 +70,8 @@ class ModelLoader(QObject):
file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin')
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
self.download_model(url, file_path, expected_sha256)
return
if self.model_type == ModelType.WHISPER:
elif self.model_type == ModelType.WHISPER:
root_dir = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
@ -82,9 +81,8 @@ class ModelLoader(QObject):
file_path = os.path.join(root_dir, os.path.basename(url))
expected_sha256 = url.split('/')[-2]
self.download_model(url, file_path, expected_sha256)
return
if self.model_type == ModelType.HUGGING_FACE:
else: # ModelType.HUGGING_FACE:
self.progress.emit((0, 100))
try:
@ -95,8 +93,9 @@ class ModelLoader(QObject):
return
self.progress.emit((100, 100))
self.finished.emit(self.hugging_face_model_id)
return
file_path = self.hugging_face_model_id
self.finished.emit(file_path)
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
try:
@ -108,14 +107,12 @@ class ModelLoader(QObject):
if os.path.isfile(file_path):
if expected_sha256 is None:
self.finished.emit(file_path)
return
return file_path
model_bytes = open(file_path, "rb").read()
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
if model_sha256 == expected_sha256:
self.finished.emit(file_path)
return
return file_path
else:
warnings.warn(
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
@ -141,7 +138,7 @@ class ModelLoader(QObject):
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
"model.")
self.finished.emit(file_path)
return file_path
except RuntimeError as exc:
self.error.emit(str(exc))
logging.exception('')

View file

@ -1,7 +1,6 @@
import logging
import os.path
import pathlib
import platform
from unittest.mock import Mock, patch
import pytest
@ -12,8 +11,7 @@ from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication
from pytestqt.qtbot import QtBot
from buzz.cache import TasksCache
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
AudioDevicesComboBox, DownloadModelProgressDialog,
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, AudioDevicesComboBox, DownloadModelProgressDialog,
FileTranscriberWidget, LanguagesComboBox, MainWindow,
RecordingTranscriberWidget,
TemperatureValidator, TextDisplayBox,
@ -25,30 +23,26 @@ from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
from tests.mock_sounddevice import MockInputStream
class TestApplication:
# FIXME: this seems to break the tests if not run??
app = Application()
def test_should_open_application(self):
assert self.app is not None
class TestLanguagesComboBox:
languagesComboxBox = LanguagesComboBox('en')
def test_should_show_sorted_whisper_languages(self):
assert self.languagesComboxBox.itemText(0) == 'Detect Language'
assert self.languagesComboxBox.itemText(10) == 'Belarusian'
assert self.languagesComboxBox.itemText(20) == 'Dutch'
assert self.languagesComboxBox.itemText(30) == 'Gujarati'
assert self.languagesComboxBox.itemText(40) == 'Japanese'
assert self.languagesComboxBox.itemText(50) == 'Lithuanian'
def test_should_show_sorted_whisper_languages(self, qtbot):
languages_combox_box = LanguagesComboBox('en')
qtbot.add_widget(languages_combox_box)
assert languages_combox_box.itemText(0) == 'Detect Language'
assert languages_combox_box.itemText(10) == 'Belarusian'
assert languages_combox_box.itemText(20) == 'Dutch'
assert languages_combox_box.itemText(30) == 'Gujarati'
assert languages_combox_box.itemText(40) == 'Japanese'
assert languages_combox_box.itemText(50) == 'Lithuanian'
def test_should_select_en_as_default_language(self):
assert self.languagesComboxBox.currentText() == 'English'
def test_should_select_en_as_default_language(self, qtbot):
languages_combox_box = LanguagesComboBox('en')
qtbot.add_widget(languages_combox_box)
assert languages_combox_box.currentText() == 'English'
def test_should_select_detect_language_as_default(self):
def test_should_select_detect_language_as_default(self, qtbot):
languages_combo_box = LanguagesComboBox(None)
qtbot.add_widget(languages_combo_box)
assert languages_combo_box.currentText() == 'Detect Language'
@ -185,17 +179,16 @@ class TestMainWindow:
class TestFileTranscriberWidget:
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
def test_should_set_window_title(self, qtbot: QtBot):
qtbot.addWidget(self.widget)
assert self.widget.windowTitle() == 'whisper-french.mp3'
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
qtbot.add_widget(widget)
assert widget.windowTitle() == 'whisper-french.mp3'
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
qtbot.addWidget(widget)
qtbot.add_widget(widget)
mock_triggered = Mock()
widget.triggered.connect(mock_triggered)
@ -254,31 +247,41 @@ class TestTemperatureValidator:
class TestTranscriptionViewerWidget:
widget = TranscriptionViewerWidget(
transcription_task=FileTranscriptionTask(
id=0,
file_path='testdata/whisper-french.mp3',
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']),
transcription_options=TranscriptionOptions(),
segments=[Segment(40, 299, 'Bien'),
Segment(299, 329, 'venue dans')],
model_path=''))
def test_should_display_segments(self, qtbot: QtBot):
qtbot.add_widget(self.widget)
widget = TranscriptionViewerWidget(
transcription_task=FileTranscriptionTask(
id=0,
file_path='testdata/whisper-french.mp3',
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']),
transcription_options=TranscriptionOptions(),
segments=[Segment(40, 299, 'Bien'),
Segment(299, 329, 'venue dans')],
model_path=''))
qtbot.add_widget(widget)
assert self.widget.windowTitle() == 'whisper-french.mp3'
assert widget.windowTitle() == 'whisper-french.mp3'
text_display_box = self.widget.findChild(TextDisplayBox)
text_display_box = widget.findChild(TextDisplayBox)
assert isinstance(text_display_box, TextDisplayBox)
assert text_display_box.toPlainText(
) == '00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans'
def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot):
qtbot.add_widget(self.widget)
widget = TranscriptionViewerWidget(
transcription_task=FileTranscriptionTask(
id=0,
file_path='testdata/whisper-french.mp3',
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']),
transcription_options=TranscriptionOptions(),
segments=[Segment(40, 299, 'Bien'),
Segment(299, 329, 'venue dans')],
model_path=''))
qtbot.add_widget(widget)
export_button = self.widget.findChild(QPushButton)
export_button = widget.findChild(QPushButton)
assert isinstance(export_button, QPushButton)
output_file_path = tmp_path / 'whisper.txt'
@ -291,10 +294,10 @@ class TestTranscriptionViewerWidget:
class TestTranscriptionTasksTableWidget:
widget = TranscriptionTasksTableWidget()
def test_upsert_task(self, qtbot: QtBot):
qtbot.add_widget(self.widget)
widget = TranscriptionTasksTableWidget()
qtbot.add_widget(widget)
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
transcription_options=TranscriptionOptions(),
@ -302,28 +305,28 @@ class TestTranscriptionTasksTableWidget:
file_paths=['testdata/whisper-french.mp3']), model_path='',
status=FileTranscriptionTask.Status.QUEUED)
self.widget.upsert_task(task)
widget.upsert_task(task)
assert self.widget.rowCount() == 1
assert self.widget.item(0, 1).text() == 'whisper-french.mp3'
assert self.widget.item(0, 2).text() == 'Queued'
assert widget.rowCount() == 1
assert widget.item(0, 1).text() == 'whisper-french.mp3'
assert widget.item(0, 2).text() == 'Queued'
task.status = FileTranscriptionTask.Status.IN_PROGRESS
task.fraction_completed = 0.3524
self.widget.upsert_task(task)
widget.upsert_task(task)
assert self.widget.rowCount() == 1
assert self.widget.item(0, 1).text() == 'whisper-french.mp3'
assert self.widget.item(0, 2).text() == 'In Progress (35%)'
assert widget.rowCount() == 1
assert widget.item(0, 1).text() == 'whisper-french.mp3'
assert widget.item(0, 2).text() == 'In Progress (35%)'
@pytest.mark.skip()
class TestRecordingTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = RecordingTranscriberWidget()
qtbot.add_widget(widget)
assert widget.windowTitle() == 'Live Recording'
@pytest.mark.skip()
def test_should_transcribe(self, qtbot):
widget = RecordingTranscriberWidget()
qtbot.add_widget(widget)