Fix network tests

This commit is contained in:
Chidi Williams 2023-01-11 15:35:29 +00:00
commit cd69363cf8
3 changed files with 121 additions and 94 deletions

View file

@ -227,7 +227,8 @@ class FileTranscriberWidget(QWidget):
# (TranscriptionOptions, FileTranscriptionOptions, str)
triggered = pyqtSignal(tuple)
def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
def __init__(self, network_access_manager: QNetworkAccessManager, file_paths: List[str],
parent: Optional[QWidget] = None,
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
super().__init__(parent, flags)
@ -240,8 +241,9 @@ class FileTranscriberWidget(QWidget):
layout = QVBoxLayout(self)
transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options, parent=self)
transcription_options_group_box = TranscriptionOptionsGroupBox(network_access_manager=network_access_manager,
default_transcription_options=self.transcription_options,
parent=self)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)
@ -474,7 +476,8 @@ class RecordingTranscriberWidget(QWidget):
STOPPED = auto()
RECORDING = auto()
def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None) -> None:
def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None,
flags: Optional[Qt.WindowType] = None) -> None:
super().__init__(parent)
if flags is not None:
@ -500,8 +503,9 @@ class RecordingTranscriberWidget(QWidget):
self.text_box = TextDisplayBox(self)
self.text_box.setPlaceholderText(_('Click Record to begin...'))
transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options, parent=self)
transcription_options_group_box = TranscriptionOptionsGroupBox(network_access_manager=network_access_manager,
default_transcription_options=self.transcription_options,
parent=self)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)
@ -691,8 +695,7 @@ class AboutDialog(QDialog):
GITHUB_API_LATEST_RELEASE_URL = 'https://api.github.com/repos/chidiwilliams/buzz/releases/latest'
GITHUB_LATEST_RELEASE_URL = 'https://github.com/chidiwilliams/buzz/releases/latest'
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None) -> None:
def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None) -> None:
super().__init__(parent)
self.setFixedSize(200, 250)
@ -700,11 +703,7 @@ class AboutDialog(QDialog):
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
self.setWindowTitle(f'{_("About")} {APP_NAME}')
if network_access_manager is None:
network_access_manager = QNetworkAccessManager()
self.network_access_manager = network_access_manager
self.network_access_manager.finished.connect(self.on_latest_release_reply)
layout = QVBoxLayout(self)
@ -745,18 +744,21 @@ class AboutDialog(QDialog):
def on_click_check_for_updates(self):
url = QUrl(self.GITHUB_API_LATEST_RELEASE_URL)
self.network_access_manager.get(QNetworkRequest(url))
self.check_updates_button.setDisabled(True)
reply = self.network_access_manager.get(QNetworkRequest(url))
def on_latest_release_reply(self, reply: QNetworkReply):
if reply.error() == QNetworkReply.NetworkError.NoError:
response = json.loads(reply.readAll().data())
tag_name = response.get('name')
if self.is_version_lower(VERSION, tag_name[1:]):
QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL))
else:
QMessageBox.information(self, '', _("You're up to date!"))
self.check_updates_button.setEnabled(True)
def on_reply_finished():
if reply.error() == QNetworkReply.NetworkError.NoError:
response = json.loads(reply.readAll().data())
tag_name = response.get('name')
if self.is_version_lower(VERSION, tag_name[1:]):
QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL))
else:
QMessageBox.information(self, '', _("You're up to date!"))
self.check_updates_button.setEnabled(True)
reply.deleteLater()
reply.finished.connect(on_reply_finished)
self.check_updates_button.setDisabled(True)
@staticmethod
def is_version_lower(version_a: str, version_b: str):
@ -847,9 +849,11 @@ class MainWindowToolbar(QToolBar):
ICON_LIGHT_THEME_BACKGROUND = '#555'
ICON_DARK_THEME_BACKGROUND = '#AAA'
def __init__(self, parent: Optional[QWidget]):
def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None):
super().__init__(parent)
self.network_access_manager = network_access_manager
record_action = QAction(self.load_icon(RECORD_ICON_PATH), _('Record'), self)
record_action.triggered.connect(self.on_record_action_triggered)
@ -906,7 +910,8 @@ class MainWindowToolbar(QToolBar):
return QIcon(pixmap)
def on_record_action_triggered(self):
recording_transcriber_window = RecordingTranscriberWidget(self, flags=Qt.WindowType.Window)
recording_transcriber_window = RecordingTranscriberWidget(network_access_manager=self.network_access_manager,
parent=self, flags=Qt.WindowType.Window)
recording_transcriber_window.show()
def set_stop_transcription_action_enabled(self, enabled: bool):
@ -936,7 +941,9 @@ class MainWindow(QMainWindow):
self.tasks = {}
self.tasks_changed.connect(self.on_tasks_changed)
self.toolbar = MainWindowToolbar(self)
self.network_access_manager = QNetworkAccessManager()
self.toolbar = MainWindowToolbar(network_access_manager=self.network_access_manager, parent=self)
self.toolbar.new_transcription_action_triggered.connect(self.on_new_transcription_action_triggered)
self.toolbar.open_transcript_action_triggered.connect(self.on_open_transcript_action_triggered)
self.toolbar.clear_history_action_triggered.connect(self.on_clear_history_action_triggered)
@ -944,7 +951,7 @@ class MainWindow(QMainWindow):
self.addToolBar(self.toolbar)
self.setUnifiedTitleAndToolBarOnMac(True)
menu_bar = MenuBar(self)
menu_bar = MenuBar(self.network_access_manager, self)
menu_bar.import_action_triggered.connect(
self.on_new_transcription_action_triggered)
self.setMenuBar(menu_bar)
@ -1017,7 +1024,7 @@ class MainWindow(QMainWindow):
return
file_transcriber_window = FileTranscriberWidget(
file_paths, self, flags=Qt.WindowType.Window)
self.network_access_manager, file_paths, self, flags=Qt.WindowType.Window)
file_transcriber_window.triggered.connect(
self.on_file_transcriber_triggered)
file_transcriber_window.show()
@ -1105,8 +1112,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
model_selected = pyqtSignal(str)
popup: QListWidget
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None):
def __init__(self, network_access_manager: QNetworkAccessManager, parent: Optional[QWidget] = None):
super().__init__('', parent)
self.setMinimumWidth(150)
@ -1121,11 +1127,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
self.textEdited.connect(self.timer.start)
self.textEdited.connect(self.on_text_edited)
if network_access_manager is None:
network_access_manager = QNetworkAccessManager(self)
self.network_manager = network_access_manager
self.network_manager.finished.connect(self.on_request_response)
self.popup = QListWidget()
self.popup.setWindowFlags(Qt.WindowType.Popup)
@ -1161,36 +1163,37 @@ class HuggingFaceSearchLineEdit(LineEdit):
url.setQuery(query)
return self.network_manager.get(QNetworkRequest(url))
reply = self.network_manager.get(QNetworkRequest(url))
def on_reply_finished():
if reply.error() != QNetworkReply.NetworkError.NoError:
logging.debug('Error fetching Hugging Face models: %s', reply.error())
return
models = json.loads(reply.readAll().data())
self.popup.setUpdatesEnabled(False)
self.popup.clear()
for model in models:
model_id = model.get('id')
item = QListWidgetItem(self.popup)
item.setText(model_id)
item.setData(Qt.ItemDataRole.UserRole, model_id)
self.popup.setCurrentItem(self.popup.item(0))
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
self.popup.setFixedHeight(
self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
self.popup.setUpdatesEnabled(True)
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
self.popup.setFocus()
self.popup.show()
reply.finished.connect(on_reply_finished)
def on_popup_selected(self):
self.timer.stop()
def on_request_response(self, network_reply: QNetworkReply):
if network_reply.error() != QNetworkReply.NetworkError.NoError:
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
return
models = json.loads(network_reply.readAll().data())
self.popup.setUpdatesEnabled(False)
self.popup.clear()
for model in models:
model_id = model.get('id')
item = QListWidgetItem(self.popup)
item.setText(model_id)
item.setData(Qt.ItemDataRole.UserRole, model_id)
self.popup.setCurrentItem(self.popup.item(0))
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
self.popup.setUpdatesEnabled(True)
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
self.popup.setFocus()
self.popup.show()
def eventFilter(self, target: QObject, event: QEvent):
if hasattr(self, 'popup') is False or target != self.popup:
return False
@ -1227,7 +1230,8 @@ class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
def __init__(self, network_access_manager: QNetworkAccessManager,
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
parent: Optional[QWidget] = None):
super().__init__(title='', parent=parent)
self.transcription_options = default_transcription_options
@ -1249,7 +1253,8 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.advanced_settings_button.clicked.connect(
self.open_advanced_settings)
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit(network_access_manager=network_access_manager,
parent=self)
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
self.model_type_combo_box = QComboBox(self)
@ -1329,9 +1334,11 @@ class TranscriptionOptionsGroupBox(QGroupBox):
class MenuBar(QMenuBar):
import_action_triggered = pyqtSignal()
def __init__(self, parent: QWidget):
def __init__(self, network_access_manager: QNetworkAccessManager, parent: QWidget):
super().__init__(parent)
self.network_access_manager = network_access_manager
import_action = QAction(_("Import Media File..."), self)
import_action.triggered.connect(
self.on_import_action_triggered)
@ -1350,7 +1357,7 @@ class MenuBar(QMenuBar):
self.import_action_triggered.emit()
def on_about_action_triggered(self):
about_dialog = AboutDialog(parent=self)
about_dialog = AboutDialog(parent=self, network_access_manager=self.network_access_manager)
about_dialog.open()

View file

@ -154,7 +154,8 @@ class TestMainWindow:
assert open_transcript_action.isEnabled() is False
table_widget: QTableWidget = window.findChild(QTableWidget)
qtbot.wait_until(self.assert_task_status(table_widget, 0, 'Completed'), timeout=2 * 60 * 1000)
qtbot.wait_until(self.assert_task_status(table_widget, 0, 'Completed'), timeout=60 * 1000)
table_widget.setCurrentIndex(table_widget.indexFromItem(table_widget.item(0, 1)))
assert open_transcript_action.isEnabled()
@ -172,7 +173,7 @@ class TestMainWindow:
assert table_widget.item(0, 1).text() == 'whisper-french.mp3'
assert 'In Progress' in table_widget.item(0, 2).text()
qtbot.wait_until(assert_task_in_progress, timeout=2 * 60 * 1000)
qtbot.wait_until(assert_task_in_progress, timeout=60 * 1000)
# Stop task in progress
table_widget.selectRow(0)
@ -234,14 +235,14 @@ class TestMainWindow:
class TestFileTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
widget = FileTranscriberWidget(network_access_manager=MockNetworkAccessManager(),
file_paths=['testdata/whisper-french.mp3'], parent=None)
qtbot.add_widget(widget)
assert widget.windowTitle() == 'whisper-french.mp3'
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
widget = FileTranscriberWidget(network_access_manager=MockNetworkAccessManager(),
file_paths=['testdata/whisper-french.mp3'], parent=None)
qtbot.add_widget(widget)
mock_triggered = Mock()
@ -268,10 +269,12 @@ class TestAboutDialog:
mock_message_box_information = Mock()
QMessageBox.information = mock_message_box_information
with qtbot.wait_signal(dialog.network_access_manager.finished):
dialog.check_updates_button.click()
dialog.check_updates_button.click()
mock_message_box_information.assert_called_with(dialog, '', "You're up to date!")
def assert_message():
mock_message_box_information.assert_called_with(dialog, '', "You're up to date!")
qtbot.wait_until(assert_message)
class TestAdvancedSettingsDialog:
@ -386,12 +389,12 @@ class TestTranscriptionTasksTableWidget:
class TestRecordingTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = RecordingTranscriberWidget()
widget = RecordingTranscriberWidget(network_access_manager=MockNetworkAccessManager())
qtbot.add_widget(widget)
assert widget.windowTitle() == 'Live Recording'
def test_should_transcribe(self, qtbot):
widget = RecordingTranscriberWidget()
widget = RecordingTranscriberWidget(network_access_manager=MockNetworkAccessManager())
qtbot.add_widget(widget)
def assert_text_box_contains_text():
@ -414,17 +417,27 @@ class TestHuggingFaceSearchLineEdit:
mock_model_selected = Mock()
widget.model_selected.connect(mock_model_selected)
self._set_text_and_wait_response(qtbot, widget)
mock_model_selected.assert_called_with('openai/whisper-tiny')
widget.setText('openai/whisper-tiny')
widget.textEdited.emit('openai/whisper-tiny')
def assert_model_selected_called():
mock_model_selected.assert_called()
mock_model_selected.assert_called_with('openai/whisper-tiny')
qtbot.wait_until(assert_model_selected_called, timeout=30 * 6000)
def test_should_show_list_of_models(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
qtbot.add_widget(widget)
self._set_text_and_wait_response(qtbot, widget)
widget.setText('openai/whisper-tiny')
widget.textEdited.emit('openai/whisper-tiny')
assert widget.popup.count() > 0
assert 'openai/whisper-tiny' in widget.popup.item(0).text()
def assert_popup_item_added():
assert widget.popup.count() > 0
assert 'openai/whisper-tiny' in widget.popup.item(0).text()
qtbot.wait_until(assert_popup_item_added, timeout=30 * 6000)
def test_should_select_model_from_list(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
@ -433,7 +446,13 @@ class TestHuggingFaceSearchLineEdit:
mock_model_selected = Mock()
widget.model_selected.connect(mock_model_selected)
self._set_text_and_wait_response(qtbot, widget)
widget.setText('openai/whisper-tiny')
widget.textEdited.emit('openai/whisper-tiny')
def assert_popup_item_added():
assert widget.popup.count() > 0
qtbot.wait_until(assert_popup_item_added, timeout=30 * 6000)
# press down arrow and enter to select next item
QApplication.sendEvent(widget.popup,
@ -448,16 +467,10 @@ class TestHuggingFaceSearchLineEdit:
reply = MockNetworkReply(data=[{'id': 'openai/whisper-tiny'}, {'id': 'openai/whisper-tiny.en'}])
return MockNetworkAccessManager(reply=reply)
@staticmethod
def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit):
with qtbot.wait_signal(widget.network_manager.finished):
widget.setText('openai/whisper-tiny')
widget.textEdited.emit('openai/whisper-tiny')
class TestTranscriptionOptionsGroupBox:
def test_should_update_model_type(self, qtbot):
widget = TranscriptionOptionsGroupBox()
widget = TranscriptionOptionsGroupBox(network_access_manager=MockNetworkAccessManager())
qtbot.add_widget(widget)
mock_transcription_options_changed = Mock()

View file

@ -1,30 +1,37 @@
import json
import time
from threading import Thread
from typing import Optional
from PyQt6.QtCore import QByteArray, QObject, QSize, Qt, pyqtSignal
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply
from PyQt6.QtCore import QByteArray, QObject
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
class MockNetworkReply(QNetworkReply):
def __init__(self, data: object, _: Optional[QObject] = None) -> None:
super().__init__()
self.data = data
def readAll(self) -> 'QByteArray':
def readAll(self) -> QByteArray:
return QByteArray(json.dumps(self.data).encode('utf-8'))
def error(self) -> 'QNetworkReply.NetworkError':
def error(self) -> QNetworkReply.NetworkError:
return QNetworkReply.NetworkError.NoError
class MockNetworkAccessManager(QNetworkAccessManager):
finished = pyqtSignal(object)
reply: MockNetworkReply
reply_thread: Optional[Thread]
def __init__(self, reply: MockNetworkReply, parent: Optional[QObject] = None) -> None:
def __init__(self, reply: Optional[MockNetworkReply] = None, parent: Optional[QObject] = None) -> None:
super().__init__(parent)
self.reply = reply
def get(self, _: 'QNetworkRequest') -> 'QNetworkReply':
self.finished.emit(self.reply)
def get(self, _: QNetworkRequest) -> QNetworkReply:
def target():
time.sleep(0.1)
self.reply.finished.emit()
self.reply_thread = Thread(target=target)
self.reply_thread.start()
return self.reply