From cd69363cf8e3bb5aeaf512338de1b6aaac9f3e3d Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Wed, 11 Jan 2023 15:35:29 +0000 Subject: [PATCH] Fix network tests --- buzz/gui.py | 135 ++++++++++++++++++++++++---------------------- tests/gui_test.py | 61 ++++++++++++--------- tests/mock_qt.py | 25 +++++---- 3 files changed, 124 insertions(+), 97 deletions(-) diff --git a/buzz/gui.py b/buzz/gui.py index df0b08f1..f2450b9b 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -227,7 +227,8 @@ class FileTranscriberWidget(QWidget): # (TranscriptionOptions, FileTranscriptionOptions, str) triggered = pyqtSignal(tuple) - def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None, + def __init__(self, network_access_manager: QNetworkAccessManager, file_paths: List[str], + parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None: super().__init__(parent, flags) @@ -240,8 +241,9 @@ class FileTranscriberWidget(QWidget): layout = QVBoxLayout(self) - transcription_options_group_box = TranscriptionOptionsGroupBox( - default_transcription_options=self.transcription_options, parent=self) + transcription_options_group_box = TranscriptionOptionsGroupBox(network_access_manager=network_access_manager, + default_transcription_options=self.transcription_options, + parent=self) transcription_options_group_box.transcription_options_changed.connect( self.on_transcription_options_changed) @@ -474,7 +476,8 @@ class RecordingTranscriberWidget(QWidget): STOPPED = auto() RECORDING = auto() - def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None) -> None: + def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None, + flags: Optional[Qt.WindowType] = None) -> None: super().__init__(parent) if flags is not None: @@ -500,8 +503,9 @@ class RecordingTranscriberWidget(QWidget): self.text_box = TextDisplayBox(self) self.text_box.setPlaceholderText(_('Click Record to begin...')) - transcription_options_group_box = TranscriptionOptionsGroupBox( - default_transcription_options=self.transcription_options, parent=self) + transcription_options_group_box = TranscriptionOptionsGroupBox(network_access_manager=network_access_manager, + default_transcription_options=self.transcription_options, + parent=self) transcription_options_group_box.transcription_options_changed.connect( self.on_transcription_options_changed) @@ -691,8 +695,7 @@ class AboutDialog(QDialog): GITHUB_API_LATEST_RELEASE_URL = 'https://api.github.com/repos/chidiwilliams/buzz/releases/latest' GITHUB_LATEST_RELEASE_URL = 'https://github.com/chidiwilliams/buzz/releases/latest' - def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None, - parent: Optional[QWidget] = None) -> None: + def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None) -> None: super().__init__(parent) self.setFixedSize(200, 250) @@ -700,11 +703,7 @@ class AboutDialog(QDialog): self.setWindowIcon(QIcon(BUZZ_ICON_PATH)) self.setWindowTitle(f'{_("About")} {APP_NAME}') - if network_access_manager is None: - network_access_manager = QNetworkAccessManager() - self.network_access_manager = network_access_manager - self.network_access_manager.finished.connect(self.on_latest_release_reply) layout = QVBoxLayout(self) @@ -745,18 +744,21 @@ class AboutDialog(QDialog): def on_click_check_for_updates(self): url = QUrl(self.GITHUB_API_LATEST_RELEASE_URL) - self.network_access_manager.get(QNetworkRequest(url)) - self.check_updates_button.setDisabled(True) + reply = self.network_access_manager.get(QNetworkRequest(url)) - def on_latest_release_reply(self, reply: QNetworkReply): - if reply.error() == QNetworkReply.NetworkError.NoError: - response = json.loads(reply.readAll().data()) - tag_name = response.get('name') - if self.is_version_lower(VERSION, tag_name[1:]): - QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL)) - else: - QMessageBox.information(self, '', _("You're up to date!")) - self.check_updates_button.setEnabled(True) + def on_reply_finished(): + if reply.error() == QNetworkReply.NetworkError.NoError: + response = json.loads(reply.readAll().data()) + tag_name = response.get('name') + if self.is_version_lower(VERSION, tag_name[1:]): + QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL)) + else: + QMessageBox.information(self, '', _("You're up to date!")) + self.check_updates_button.setEnabled(True) + reply.deleteLater() + + reply.finished.connect(on_reply_finished) + self.check_updates_button.setDisabled(True) @staticmethod def is_version_lower(version_a: str, version_b: str): @@ -847,9 +849,11 @@ class MainWindowToolbar(QToolBar): ICON_LIGHT_THEME_BACKGROUND = '#555' ICON_DARK_THEME_BACKGROUND = '#AAA' - def __init__(self, parent: Optional[QWidget]): + def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None): super().__init__(parent) + self.network_access_manager = network_access_manager + record_action = QAction(self.load_icon(RECORD_ICON_PATH), _('Record'), self) record_action.triggered.connect(self.on_record_action_triggered) @@ -906,7 +910,8 @@ class MainWindowToolbar(QToolBar): return QIcon(pixmap) def on_record_action_triggered(self): - recording_transcriber_window = RecordingTranscriberWidget(self, flags=Qt.WindowType.Window) + recording_transcriber_window = RecordingTranscriberWidget(network_access_manager=self.network_access_manager, + parent=self, flags=Qt.WindowType.Window) recording_transcriber_window.show() def set_stop_transcription_action_enabled(self, enabled: bool): @@ -936,7 +941,9 @@ class MainWindow(QMainWindow): self.tasks = {} self.tasks_changed.connect(self.on_tasks_changed) - self.toolbar = MainWindowToolbar(self) + self.network_access_manager = QNetworkAccessManager() + + self.toolbar = MainWindowToolbar(network_access_manager=self.network_access_manager, parent=self) self.toolbar.new_transcription_action_triggered.connect(self.on_new_transcription_action_triggered) self.toolbar.open_transcript_action_triggered.connect(self.on_open_transcript_action_triggered) self.toolbar.clear_history_action_triggered.connect(self.on_clear_history_action_triggered) @@ -944,7 +951,7 @@ class MainWindow(QMainWindow): self.addToolBar(self.toolbar) self.setUnifiedTitleAndToolBarOnMac(True) - menu_bar = MenuBar(self) + menu_bar = MenuBar(self.network_access_manager, self) menu_bar.import_action_triggered.connect( self.on_new_transcription_action_triggered) self.setMenuBar(menu_bar) @@ -1017,7 +1024,7 @@ class MainWindow(QMainWindow): return file_transcriber_window = FileTranscriberWidget( - file_paths, self, flags=Qt.WindowType.Window) + self.network_access_manager, file_paths, self, flags=Qt.WindowType.Window) file_transcriber_window.triggered.connect( self.on_file_transcriber_triggered) file_transcriber_window.show() @@ -1105,8 +1112,7 @@ class HuggingFaceSearchLineEdit(LineEdit): model_selected = pyqtSignal(str) popup: QListWidget - def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None, - parent: Optional[QWidget] = None): + def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None): super().__init__('', parent) self.setMinimumWidth(150) @@ -1121,11 +1127,7 @@ class HuggingFaceSearchLineEdit(LineEdit): self.textEdited.connect(self.timer.start) self.textEdited.connect(self.on_text_edited) - if network_access_manager is None: - network_access_manager = QNetworkAccessManager(self) - self.network_manager = network_access_manager - self.network_manager.finished.connect(self.on_request_response) self.popup = QListWidget() self.popup.setWindowFlags(Qt.WindowType.Popup) @@ -1161,36 +1163,37 @@ class HuggingFaceSearchLineEdit(LineEdit): url.setQuery(query) - return self.network_manager.get(QNetworkRequest(url)) + reply = self.network_manager.get(QNetworkRequest(url)) + + def on_reply_finished(): + if reply.error() != QNetworkReply.NetworkError.NoError: + logging.debug('Error fetching Hugging Face models: %s', reply.error()) + return + + models = json.loads(reply.readAll().data()) + + self.popup.setUpdatesEnabled(False) + self.popup.clear() + for model in models: + model_id = model.get('id') + item = QListWidgetItem(self.popup) + item.setText(model_id) + item.setData(Qt.ItemDataRole.UserRole, model_id) + + self.popup.setCurrentItem(self.popup.item(0)) + self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20) + self.popup.setFixedHeight( + self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll + self.popup.setUpdatesEnabled(True) + self.popup.move(self.mapToGlobal(QPoint(0, self.height()))) + self.popup.setFocus() + self.popup.show() + + reply.finished.connect(on_reply_finished) def on_popup_selected(self): self.timer.stop() - def on_request_response(self, network_reply: QNetworkReply): - if network_reply.error() != QNetworkReply.NetworkError.NoError: - logging.debug('Error fetching Hugging Face models: %s', network_reply.error()) - return - - models = json.loads(network_reply.readAll().data()) - - self.popup.setUpdatesEnabled(False) - self.popup.clear() - - for model in models: - model_id = model.get('id') - - item = QListWidgetItem(self.popup) - item.setText(model_id) - item.setData(Qt.ItemDataRole.UserRole, model_id) - - self.popup.setCurrentItem(self.popup.item(0)) - self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20) - self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll - self.popup.setUpdatesEnabled(True) - self.popup.move(self.mapToGlobal(QPoint(0, self.height()))) - self.popup.setFocus() - self.popup.show() - def eventFilter(self, target: QObject, event: QEvent): if hasattr(self, 'popup') is False or target != self.popup: return False @@ -1227,7 +1230,8 @@ class TranscriptionOptionsGroupBox(QGroupBox): transcription_options: TranscriptionOptions transcription_options_changed = pyqtSignal(TranscriptionOptions) - def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(), + def __init__(self, network_access_manager: QNetworkAccessManager, + default_transcription_options: TranscriptionOptions = TranscriptionOptions(), parent: Optional[QWidget] = None): super().__init__(title='', parent=parent) self.transcription_options = default_transcription_options @@ -1249,7 +1253,8 @@ class TranscriptionOptionsGroupBox(QGroupBox): self.advanced_settings_button.clicked.connect( self.open_advanced_settings) - self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit() + self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit(network_access_manager=network_access_manager, + parent=self) self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed) self.model_type_combo_box = QComboBox(self) @@ -1329,9 +1334,11 @@ class TranscriptionOptionsGroupBox(QGroupBox): class MenuBar(QMenuBar): import_action_triggered = pyqtSignal() - def __init__(self, parent: QWidget): + def __init__(self, network_access_manager: QNetworkAccessManager, parent: QWidget): super().__init__(parent) + self.network_access_manager = network_access_manager + import_action = QAction(_("Import Media File..."), self) import_action.triggered.connect( self.on_import_action_triggered) @@ -1350,7 +1357,7 @@ class MenuBar(QMenuBar): self.import_action_triggered.emit() def on_about_action_triggered(self): - about_dialog = AboutDialog(parent=self) + about_dialog = AboutDialog(parent=self, network_access_manager=self.network_access_manager) about_dialog.open() diff --git a/tests/gui_test.py b/tests/gui_test.py index 97d00c68..6abbf368 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -154,7 +154,8 @@ class TestMainWindow: assert open_transcript_action.isEnabled() is False table_widget: QTableWidget = window.findChild(QTableWidget) - qtbot.wait_until(self.assert_task_status(table_widget, 0, 'Completed'), timeout=2 * 60 * 1000) + + qtbot.wait_until(self.assert_task_status(table_widget, 0, 'Completed'), timeout=60 * 1000) table_widget.setCurrentIndex(table_widget.indexFromItem(table_widget.item(0, 1))) assert open_transcript_action.isEnabled() @@ -172,7 +173,7 @@ class TestMainWindow: assert table_widget.item(0, 1).text() == 'whisper-french.mp3' assert 'In Progress' in table_widget.item(0, 2).text() - qtbot.wait_until(assert_task_in_progress, timeout=2 * 60 * 1000) + qtbot.wait_until(assert_task_in_progress, timeout=60 * 1000) # Stop task in progress table_widget.selectRow(0) @@ -234,14 +235,14 @@ class TestMainWindow: class TestFileTranscriberWidget: def test_should_set_window_title(self, qtbot: QtBot): - widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], parent=None) + widget = FileTranscriberWidget(network_access_manager=MockNetworkAccessManager(), + file_paths=['testdata/whisper-french.mp3'], parent=None) qtbot.add_widget(widget) assert widget.windowTitle() == 'whisper-french.mp3' def test_should_emit_triggered_event(self, qtbot: QtBot): - widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], parent=None) + widget = FileTranscriberWidget(network_access_manager=MockNetworkAccessManager(), + file_paths=['testdata/whisper-french.mp3'], parent=None) qtbot.add_widget(widget) mock_triggered = Mock() @@ -268,10 +269,12 @@ class TestAboutDialog: mock_message_box_information = Mock() QMessageBox.information = mock_message_box_information - with qtbot.wait_signal(dialog.network_access_manager.finished): - dialog.check_updates_button.click() + dialog.check_updates_button.click() - mock_message_box_information.assert_called_with(dialog, '', "You're up to date!") + def assert_message(): + mock_message_box_information.assert_called_with(dialog, '', "You're up to date!") + + qtbot.wait_until(assert_message) class TestAdvancedSettingsDialog: @@ -386,12 +389,12 @@ class TestTranscriptionTasksTableWidget: class TestRecordingTranscriberWidget: def test_should_set_window_title(self, qtbot: QtBot): - widget = RecordingTranscriberWidget() + widget = RecordingTranscriberWidget(network_access_manager=MockNetworkAccessManager()) qtbot.add_widget(widget) assert widget.windowTitle() == 'Live Recording' def test_should_transcribe(self, qtbot): - widget = RecordingTranscriberWidget() + widget = RecordingTranscriberWidget(network_access_manager=MockNetworkAccessManager()) qtbot.add_widget(widget) def assert_text_box_contains_text(): @@ -414,17 +417,27 @@ class TestHuggingFaceSearchLineEdit: mock_model_selected = Mock() widget.model_selected.connect(mock_model_selected) - self._set_text_and_wait_response(qtbot, widget) - mock_model_selected.assert_called_with('openai/whisper-tiny') + widget.setText('openai/whisper-tiny') + widget.textEdited.emit('openai/whisper-tiny') + + def assert_model_selected_called(): + mock_model_selected.assert_called() + mock_model_selected.assert_called_with('openai/whisper-tiny') + + qtbot.wait_until(assert_model_selected_called, timeout=30 * 6000) def test_should_show_list_of_models(self, qtbot: QtBot): widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager()) qtbot.add_widget(widget) - self._set_text_and_wait_response(qtbot, widget) + widget.setText('openai/whisper-tiny') + widget.textEdited.emit('openai/whisper-tiny') - assert widget.popup.count() > 0 - assert 'openai/whisper-tiny' in widget.popup.item(0).text() + def assert_popup_item_added(): + assert widget.popup.count() > 0 + assert 'openai/whisper-tiny' in widget.popup.item(0).text() + + qtbot.wait_until(assert_popup_item_added, timeout=30 * 6000) def test_should_select_model_from_list(self, qtbot: QtBot): widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager()) @@ -433,7 +446,13 @@ class TestHuggingFaceSearchLineEdit: mock_model_selected = Mock() widget.model_selected.connect(mock_model_selected) - self._set_text_and_wait_response(qtbot, widget) + widget.setText('openai/whisper-tiny') + widget.textEdited.emit('openai/whisper-tiny') + + def assert_popup_item_added(): + assert widget.popup.count() > 0 + + qtbot.wait_until(assert_popup_item_added, timeout=30 * 6000) # press down arrow and enter to select next item QApplication.sendEvent(widget.popup, @@ -448,16 +467,10 @@ class TestHuggingFaceSearchLineEdit: reply = MockNetworkReply(data=[{'id': 'openai/whisper-tiny'}, {'id': 'openai/whisper-tiny.en'}]) return MockNetworkAccessManager(reply=reply) - @staticmethod - def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit): - with qtbot.wait_signal(widget.network_manager.finished): - widget.setText('openai/whisper-tiny') - widget.textEdited.emit('openai/whisper-tiny') - class TestTranscriptionOptionsGroupBox: def test_should_update_model_type(self, qtbot): - widget = TranscriptionOptionsGroupBox() + widget = TranscriptionOptionsGroupBox(network_access_manager=MockNetworkAccessManager()) qtbot.add_widget(widget) mock_transcription_options_changed = Mock() diff --git a/tests/mock_qt.py b/tests/mock_qt.py index 616939f9..1df358b7 100644 --- a/tests/mock_qt.py +++ b/tests/mock_qt.py @@ -1,30 +1,37 @@ - import json +import time +from threading import Thread from typing import Optional -from PyQt6.QtCore import QByteArray, QObject, QSize, Qt, pyqtSignal -from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply +from PyQt6.QtCore import QByteArray, QObject +from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest class MockNetworkReply(QNetworkReply): def __init__(self, data: object, _: Optional[QObject] = None) -> None: + super().__init__() self.data = data - def readAll(self) -> 'QByteArray': + def readAll(self) -> QByteArray: return QByteArray(json.dumps(self.data).encode('utf-8')) - def error(self) -> 'QNetworkReply.NetworkError': + def error(self) -> QNetworkReply.NetworkError: return QNetworkReply.NetworkError.NoError class MockNetworkAccessManager(QNetworkAccessManager): - finished = pyqtSignal(object) reply: MockNetworkReply + reply_thread: Optional[Thread] - def __init__(self, reply: MockNetworkReply, parent: Optional[QObject] = None) -> None: + def __init__(self, reply: Optional[MockNetworkReply] = None, parent: Optional[QObject] = None) -> None: super().__init__(parent) self.reply = reply - def get(self, _: 'QNetworkRequest') -> 'QNetworkReply': - self.finished.emit(self.reply) + def get(self, _: QNetworkRequest) -> QNetworkReply: + def target(): + time.sleep(0.1) + self.reply.finished.emit() + + self.reply_thread = Thread(target=target) + self.reply_thread.start() return self.reply