From f5f77b3908954cafd0218700052ad964487a850f Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Fri, 4 Aug 2023 18:02:20 -0700 Subject: [PATCH] Add default file name setting (#559) --- buzz/dialogs.py | 10 + buzz/gui.py | 563 +----------------- buzz/model_loader.py | 2 +- buzz/settings/settings.py | 8 +- buzz/transcriber.py | 93 ++- buzz/widgets/menu_bar.py | 19 +- .../general_preferences_widget.py | 22 +- .../preferences_dialog/preferences_dialog.py | 8 +- buzz/widgets/transcriber/__init__.py | 0 .../transcriber/advanced_settings_button.py | 8 + .../transcriber/advanced_settings_dialog.py | 64 ++ .../transcriber/file_transcriber_widget.py | 204 +++++++ .../hugging_face_search_line_edit.py | 134 +++++ .../transcriber/languages_combo_box.py | 31 + buzz/widgets/transcriber/tasks_combo_box.py | 21 + .../transcriber/temperature_validator.py | 19 + .../transcription_options_group_box.py | 140 +++++ buzz/widgets/transcription_viewer_widget.py | 6 +- tests/gui_test.py | 41 +- tests/transcriber_test.py | 29 +- tests/widgets/file_transcriber_widget_test.py | 32 + .../general_preferences_widget_test.py | 7 +- .../preferences_dialog_test.py | 2 +- 23 files changed, 843 insertions(+), 620 deletions(-) create mode 100644 buzz/dialogs.py create mode 100644 buzz/widgets/transcriber/__init__.py create mode 100644 buzz/widgets/transcriber/advanced_settings_button.py create mode 100644 buzz/widgets/transcriber/advanced_settings_dialog.py create mode 100644 buzz/widgets/transcriber/file_transcriber_widget.py create mode 100644 buzz/widgets/transcriber/hugging_face_search_line_edit.py create mode 100644 buzz/widgets/transcriber/languages_combo_box.py create mode 100644 buzz/widgets/transcriber/tasks_combo_box.py create mode 100644 buzz/widgets/transcriber/temperature_validator.py create mode 100644 buzz/widgets/transcriber/transcription_options_group_box.py create mode 100644 tests/widgets/file_transcriber_widget_test.py diff --git a/buzz/dialogs.py b/buzz/dialogs.py new file mode 100644 index 00000000..b1b7ff59 --- /dev/null +++ b/buzz/dialogs.py @@ -0,0 +1,10 @@ +from PyQt6.QtWidgets import QWidget, QMessageBox + + +def show_model_download_error_dialog(parent: QWidget, error: str): + message = parent.tr( + 'An error occurred while loading the Whisper model') + \ + f": {error}{'' if error.endswith('.') else '.'}" + \ + parent.tr("Please retry or check the application logs for more information.") + + QMessageBox.critical(parent, '', message) diff --git a/buzz/gui.py b/buzz/gui.py index f36de249..496d024a 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -1,51 +1,44 @@ import enum -import json -import logging import sys from enum import auto from typing import Dict, List, Optional, Tuple import sounddevice from PyQt6 import QtGui -from PyQt6.QtCore import (QObject, Qt, QThread, - QTimer, QUrl, pyqtSignal, QModelIndex, QPoint, - QUrlQuery, QMetaObject, QEvent, QThreadPool) +from PyQt6.QtCore import (Qt, QThread, + pyqtSignal, QModelIndex, QThreadPool) from PyQt6.QtGui import (QCloseEvent, QIcon, - QKeySequence, QTextCursor, QValidator, QKeyEvent, QPainter, QColor) -from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest -from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog, - QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit, - QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox, - QFormLayout, - QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy) + QKeySequence, QTextCursor, QPainter, QColor) +from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit, + QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QFormLayout, + QSizePolicy) from buzz.cache import TasksCache from .__version__ import VERSION from .action import Action from .assets import get_asset_path +from .dialogs import show_model_download_error_dialog from .widgets.icon import Icon, BUZZ_ICON_PATH from .locale import _ from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \ ModelDownloader -from .paths import file_paths_as_title from .recording import RecordingAmplitudeListener from .settings.settings import Settings, APP_NAME from .settings.shortcut import Shortcut from .settings.shortcut_settings import ShortcutSettings from .store.keyring_store import KeyringStore -from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat, - Task, +from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Task, TranscriptionOptions, FileTranscriptionTask, LOADED_WHISPER_DLL, - DEFAULT_WHISPER_TEMPERATURE, LANGUAGES) + DEFAULT_WHISPER_TEMPERATURE) from .recording_transcriber import RecordingTranscriber from .file_transcriber_queue_worker import FileTranscriberQueueWorker -from .widgets.line_edit import LineEdit from .widgets.menu_bar import MenuBar from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog -from .widgets.model_type_combo_box import ModelTypeComboBox -from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit from .widgets.toolbar import ToolBar +from .widgets.transcriber.file_transcriber_widget import FileTranscriberWidget +from .widgets.transcriber.transcription_options_group_box import \ + TranscriptionOptionsGroupBox from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget from .widgets.transcription_viewer_widget import TranscriptionViewerWidget @@ -102,45 +95,6 @@ class AudioDevicesComboBox(QComboBox): return -1 -class LanguagesComboBox(QComboBox): - """LanguagesComboBox displays a list of languages available to use with Whisper""" - # language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language" - languageChanged = pyqtSignal(str) - - def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None: - super().__init__(parent) - - whisper_languages = sorted( - [(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1]) - self.languages = [('', _('Detect Language'))] + whisper_languages - - self.addItems([lang[1] for lang in self.languages]) - self.currentIndexChanged.connect(self.on_index_changed) - - default_language_key = default_language if default_language != '' else None - for i, lang in enumerate(self.languages): - if lang[0] == default_language_key: - self.setCurrentIndex(i) - - def on_index_changed(self, index: int): - self.languageChanged.emit(self.languages[index][0]) - - -class TasksComboBox(QComboBox): - """TasksComboBox displays a list of tasks available to use with Whisper""" - taskChanged = pyqtSignal(Task) - - def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None: - super().__init__(parent, *args) - self.tasks = [i for i in Task] - self.addItems(map(lambda task: task.value.title(), self.tasks)) - self.currentIndexChanged.connect(self.on_index_changed) - self.setCurrentText(default_task.value.title()) - - def on_index_changed(self, index: int): - self.taskChanged.emit(self.tasks[index]) - - class TextDisplayBox(QPlainTextEdit): """TextDisplayBox is a read-only textbox""" @@ -164,181 +118,6 @@ class RecordButton(QPushButton): self.setDefault(False) -def show_model_download_error_dialog(parent: QWidget, error: str): - message = parent.tr( - 'An error occurred while loading the Whisper model') + \ - f": {error}{'' if error.endswith('.') else '.'}" + \ - parent.tr("Please retry or check the application logs for more information.") - - QMessageBox.critical(parent, '', message) - - -class FileTranscriberWidget(QWidget): - model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None - model_loader: Optional[ModelDownloader] = None - file_transcription_options: FileTranscriptionOptions - transcription_options: TranscriptionOptions - is_transcribing = False - # (TranscriptionOptions, FileTranscriptionOptions, str) - triggered = pyqtSignal(tuple) - openai_access_token_changed = pyqtSignal(str) - settings = Settings() - - def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None, - flags: Qt.WindowType = Qt.WindowType.Widget) -> None: - super().__init__(parent, flags) - - self.setWindowTitle(file_paths_as_title(file_paths)) - - openai_access_token = KeyringStore().get_password(KeyringStore.Key.OPENAI_API_KEY) - - self.file_paths = file_paths - default_language = self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='') - self.transcription_options = TranscriptionOptions( - openai_access_token=openai_access_token, - model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL, default_value=TranscriptionModel()), - task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE), - language=default_language if default_language != '' else None, - initial_prompt=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''), - temperature=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, - default_value=DEFAULT_WHISPER_TEMPERATURE), - word_level_timings=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - default_value=False)) - default_export_format_states: List[str] = self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, - default_value=[]) - self.file_transcription_options = FileTranscriptionOptions( - file_paths=self.file_paths, - output_formats=set([OutputFormat(output_format) for output_format in default_export_format_states])) - - layout = QVBoxLayout(self) - - transcription_options_group_box = TranscriptionOptionsGroupBox( - default_transcription_options=self.transcription_options, parent=self) - transcription_options_group_box.transcription_options_changed.connect( - self.on_transcription_options_changed) - - self.word_level_timings_checkbox = QCheckBox(_('Word-level timings')) - self.word_level_timings_checkbox.setChecked( - self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, default_value=False)) - self.word_level_timings_checkbox.stateChanged.connect( - self.on_word_level_timings_changed) - - file_transcription_layout = QFormLayout() - file_transcription_layout.addRow('', self.word_level_timings_checkbox) - - export_format_layout = QHBoxLayout() - for output_format in OutputFormat: - export_format_checkbox = QCheckBox(f'{output_format.value.upper()}', parent=self) - export_format_checkbox.setChecked(output_format in self.file_transcription_options.output_formats) - export_format_checkbox.stateChanged.connect(self.get_on_checkbox_state_changed_callback(output_format)) - export_format_layout.addWidget(export_format_checkbox) - - file_transcription_layout.addRow('Export:', export_format_layout) - - self.run_button = QPushButton(_('Run'), self) - self.run_button.setDefault(True) - self.run_button.clicked.connect(self.on_click_run) - - layout.addWidget(transcription_options_group_box) - layout.addLayout(file_transcription_layout) - layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight) - - self.setLayout(layout) - self.setFixedSize(self.sizeHint()) - - def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat): - def on_checkbox_state_changed(state: int): - if state == Qt.CheckState.Checked.value: - self.file_transcription_options.output_formats.add(output_format) - elif state == Qt.CheckState.Unchecked.value: - self.file_transcription_options.output_formats.remove(output_format) - - return on_checkbox_state_changed - - def on_transcription_options_changed(self, transcription_options: TranscriptionOptions): - self.transcription_options = transcription_options - self.word_level_timings_checkbox.setDisabled( - self.transcription_options.model.model_type == ModelType.HUGGING_FACE or self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API) - if self.transcription_options.openai_access_token != '': - self.openai_access_token_changed.emit(self.transcription_options.openai_access_token) - - def on_click_run(self): - self.run_button.setDisabled(True) - - model_path = self.transcription_options.model.get_local_model_path() - if model_path is not None: - self.on_model_loaded(model_path) - return - - self.model_loader = ModelDownloader(model=self.transcription_options.model) - self.model_loader.signals.progress.connect(self.on_download_model_progress) - self.model_loader.signals.error.connect(self.on_download_model_error) - self.model_loader.signals.finished.connect(self.on_model_loaded) - QThreadPool().globalInstance().start(self.model_loader) - - def on_model_loaded(self, model_path: str): - self.reset_transcriber_controls() - - self.triggered.emit((self.transcription_options, - self.file_transcription_options, model_path)) - self.close() - - def on_download_model_progress(self, progress: Tuple[float, float]): - (current_size, total_size) = progress - - if self.model_download_progress_dialog is None: - self.model_download_progress_dialog = ModelDownloadProgressDialog( - model_type=self.transcription_options.model.model_type, parent=self) - self.model_download_progress_dialog.canceled.connect( - self.on_cancel_model_progress_dialog) - - if self.model_download_progress_dialog is not None: - self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size) - - def on_download_model_error(self, error: str): - self.reset_model_download() - show_model_download_error_dialog(self, error) - self.reset_transcriber_controls() - - def reset_transcriber_controls(self): - self.run_button.setDisabled(False) - - def on_cancel_model_progress_dialog(self): - if self.model_loader is not None: - self.model_loader.cancel() - self.reset_model_download() - - def reset_model_download(self): - if self.model_download_progress_dialog is not None: - self.model_download_progress_dialog.close() - self.model_download_progress_dialog = None - - def on_word_level_timings_changed(self, value: int): - self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value - - def closeEvent(self, event: QtGui.QCloseEvent) -> None: - if self.model_loader is not None: - self.model_loader.cancel() - - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, self.transcription_options.temperature) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, self.transcription_options.initial_prompt) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model) - self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - value=self.transcription_options.word_level_timings) - self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, - value=[export_format.value for export_format in - self.file_transcription_options.output_formats]) - - super().closeEvent(event) - - -class AdvancedSettingsButton(QPushButton): - def __init__(self, parent: Optional[QWidget]) -> None: - super().__init__('Advanced...', parent) - - class AudioMeterWidget(QWidget): current_amplitude: float BAR_WIDTH = 2 @@ -730,6 +509,9 @@ class MainWindow(QMainWindow): self.shortcut_settings = ShortcutSettings(settings=self.settings) self.shortcuts = self.shortcut_settings.load() + self.default_export_file_name = self.settings.value( + Settings.Key.DEFAULT_EXPORT_FILE_NAME, + '{{ input_file_name }} ({{ task }}d on {{ date_time }})') self.tasks = {} self.tasks_changed.connect(self.on_tasks_changed) @@ -742,11 +524,14 @@ class MainWindow(QMainWindow): self.addToolBar(self.toolbar) self.setUnifiedTitleAndToolBarOnMac(True) - self.menu_bar = MenuBar(shortcuts=self.shortcuts, parent=self) + self.menu_bar = MenuBar(shortcuts=self.shortcuts, + default_export_file_name=self.default_export_file_name, + parent=self) self.menu_bar.import_action_triggered.connect( self.on_new_transcription_action_triggered) self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed) self.menu_bar.openai_api_key_changed.connect(self.on_openai_access_token_changed) + self.menu_bar.default_export_file_name_changed.connect(self.default_export_file_name_changed) self.setMenuBar(self.menu_bar) self.table_widget = TranscriptionTasksTableWidget(self) @@ -840,6 +625,7 @@ class MainWindow(QMainWindow): def open_file_transcriber_widget(self, file_paths: List[str]): file_transcriber_window = FileTranscriberWidget(file_paths=file_paths, + default_output_file_name=self.default_export_file_name, parent=self, flags=Qt.WindowType.Window) file_transcriber_window.triggered.connect( @@ -851,6 +637,10 @@ class MainWindow(QMainWindow): def on_openai_access_token_changed(access_token: str): KeyringStore().set_password(KeyringStore.Key.OPENAI_API_KEY, access_token) + def default_export_file_name_changed(self, default_export_file_name: str): + self.default_export_file_name = default_export_file_name + self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name) + def open_transcript_viewer(self): selected_rows = self.table_widget.selectionModel().selectedRows() for selected_row in selected_rows: @@ -933,241 +723,6 @@ class MainWindow(QMainWindow): super().closeEvent(event) -# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py -class HuggingFaceSearchLineEdit(LineEdit): - model_selected = pyqtSignal(str) - popup: QListWidget - - def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None, - parent: Optional[QWidget] = None): - super().__init__('', parent) - - self.setMinimumWidth(150) - self.setPlaceholderText('openai/whisper-tiny') - - self.timer = QTimer(self) - self.timer.setSingleShot(True) - self.timer.setInterval(250) - self.timer.timeout.connect(self.fetch_models) - - # Restart debounce timer each time editor text changes - self.textEdited.connect(self.timer.start) - self.textEdited.connect(self.on_text_edited) - - if network_access_manager is None: - network_access_manager = QNetworkAccessManager(self) - - self.network_manager = network_access_manager - self.network_manager.finished.connect(self.on_request_response) - - self.popup = QListWidget() - self.popup.setWindowFlags(Qt.WindowType.Popup) - self.popup.setFocusProxy(self) - self.popup.setMouseTracking(True) - self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) - self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) - self.popup.installEventFilter(self) - self.popup.itemClicked.connect(self.on_select_item) - - def on_text_edited(self, text: str): - self.model_selected.emit(text) - - def on_select_item(self): - self.popup.hide() - self.setFocus() - - item = self.popup.currentItem() - self.setText(item.text()) - QMetaObject.invokeMethod(self, 'returnPressed') - self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole)) - - def fetch_models(self): - text = self.text() - if len(text) < 3: - return - - url = QUrl("https://huggingface.co/api/models") - - query = QUrlQuery() - query.addQueryItem("filter", "whisper") - query.addQueryItem("search", text) - - url.setQuery(query) - - return self.network_manager.get(QNetworkRequest(url)) - - def on_popup_selected(self): - self.timer.stop() - - def on_request_response(self, network_reply: QNetworkReply): - if network_reply.error() != QNetworkReply.NetworkError.NoError: - logging.debug('Error fetching Hugging Face models: %s', network_reply.error()) - return - - models = json.loads(network_reply.readAll().data()) - - self.popup.setUpdatesEnabled(False) - self.popup.clear() - - for model in models: - model_id = model.get('id') - - item = QListWidgetItem(self.popup) - item.setText(model_id) - item.setData(Qt.ItemDataRole.UserRole, model_id) - - self.popup.setCurrentItem(self.popup.item(0)) - self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20) - self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll - self.popup.setUpdatesEnabled(True) - self.popup.move(self.mapToGlobal(QPoint(0, self.height()))) - self.popup.setFocus() - self.popup.show() - - def eventFilter(self, target: QObject, event: QEvent): - if hasattr(self, 'popup') is False or target != self.popup: - return False - - if event.type() == QEvent.Type.MouseButtonPress: - self.popup.hide() - self.setFocus() - return True - - if isinstance(event, QKeyEvent): - key = event.key() - if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]: - if self.popup.currentItem() is not None: - self.on_select_item() - return True - - if key == Qt.Key.Key_Escape: - self.setFocus() - self.popup.hide() - return True - - if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp, - Qt.Key.Key_PageDown]: - return False - - self.setFocus() - self.event(event) - self.popup.hide() - - return False - - -class TranscriptionOptionsGroupBox(QGroupBox): - transcription_options: TranscriptionOptions - transcription_options_changed = pyqtSignal(TranscriptionOptions) - - def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(), - model_types: Optional[List[ModelType]] = None, - parent: Optional[QWidget] = None): - super().__init__(title='', parent=parent) - self.transcription_options = default_transcription_options - - self.form_layout = QFormLayout(self) - - self.tasks_combo_box = TasksComboBox( - default_task=self.transcription_options.task, - parent=self) - self.tasks_combo_box.taskChanged.connect(self.on_task_changed) - - self.languages_combo_box = LanguagesComboBox( - default_language=self.transcription_options.language, - parent=self) - self.languages_combo_box.languageChanged.connect( - self.on_language_changed) - - self.advanced_settings_button = AdvancedSettingsButton(self) - self.advanced_settings_button.clicked.connect( - self.open_advanced_settings) - - self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit() - self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed) - - self.model_type_combo_box = ModelTypeComboBox(model_types=model_types, - default_model=default_transcription_options.model.model_type, - parent=self) - self.model_type_combo_box.changed.connect(self.on_model_type_changed) - - self.whisper_model_size_combo_box = QComboBox(self) - self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize]) - if default_transcription_options.model.whisper_model_size is not None: - self.whisper_model_size_combo_box.setCurrentText( - default_transcription_options.model.whisper_model_size.value.title()) - self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed) - - self.openai_access_token_edit = OpenAIAPIKeyLineEdit(key=default_transcription_options.openai_access_token, - parent=self) - self.openai_access_token_edit.key_changed.connect(self.on_openai_access_token_edit_changed) - - self.form_layout.addRow(_('Model:'), self.model_type_combo_box) - self.form_layout.addRow('', self.whisper_model_size_combo_box) - self.form_layout.addRow('', self.hugging_face_search_line_edit) - self.form_layout.addRow('Access Token:', self.openai_access_token_edit) - self.form_layout.addRow(_('Task:'), self.tasks_combo_box) - self.form_layout.addRow(_('Language:'), self.languages_combo_box) - - self.reset_visible_rows() - - self.form_layout.addRow('', self.advanced_settings_button) - - self.setLayout(self.form_layout) - - def on_openai_access_token_edit_changed(self, access_token: str): - self.transcription_options.openai_access_token = access_token - self.transcription_options_changed.emit(self.transcription_options) - - def on_language_changed(self, language: str): - self.transcription_options.language = language - self.transcription_options_changed.emit(self.transcription_options) - - def on_task_changed(self, task: Task): - self.transcription_options.task = task - self.transcription_options_changed.emit(self.transcription_options) - - def on_temperature_changed(self, temperature: Tuple[float, ...]): - self.transcription_options.temperature = temperature - self.transcription_options_changed.emit(self.transcription_options) - - def on_initial_prompt_changed(self, initial_prompt: str): - self.transcription_options.initial_prompt = initial_prompt - self.transcription_options_changed.emit(self.transcription_options) - - def open_advanced_settings(self): - dialog = AdvancedSettingsDialog( - transcription_options=self.transcription_options, parent=self) - dialog.transcription_options_changed.connect( - self.on_transcription_options_changed) - dialog.exec() - - def on_transcription_options_changed(self, transcription_options: TranscriptionOptions): - self.transcription_options = transcription_options - self.transcription_options_changed.emit(transcription_options) - - def reset_visible_rows(self): - model_type = self.transcription_options.model.model_type - self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE) - self.form_layout.setRowVisible(self.whisper_model_size_combo_box, - (model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP) or ( - model_type == ModelType.FASTER_WHISPER)) - self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API) - - def on_model_type_changed(self, model_type: ModelType): - self.transcription_options.model.model_type = model_type - self.reset_visible_rows() - self.transcription_options_changed.emit(self.transcription_options) - - def on_whisper_model_size_changed(self, text: str): - model_size = WhisperModelSize(text.lower()) - self.transcription_options.model.whisper_model_size = model_size - self.transcription_options_changed.emit(self.transcription_options) - - def on_hugging_face_model_changed(self, model: str): - self.transcription_options.model.hugging_face_model_id = model - self.transcription_options_changed.emit(self.transcription_options) - class Application(QApplication): window: MainWindow @@ -1183,73 +738,3 @@ class Application(QApplication): def add_task(self, task: FileTranscriptionTask): self.window.add_task(task) - - -class AdvancedSettingsDialog(QDialog): - transcription_options: TranscriptionOptions - transcription_options_changed = pyqtSignal(TranscriptionOptions) - - def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None): - super().__init__(parent) - - self.transcription_options = transcription_options - - self.setWindowTitle(_('Advanced Settings')) - - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton( - QDialogButtonBox.StandardButton.Ok), self) - button_box.accepted.connect(self.accept) - - layout = QFormLayout(self) - - default_temperature_text = ', '.join( - [str(temp) for temp in transcription_options.temperature]) - self.temperature_line_edit = LineEdit(default_temperature_text, self) - self.temperature_line_edit.setPlaceholderText( - _('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')) - self.temperature_line_edit.setMinimumWidth(170) - self.temperature_line_edit.textChanged.connect( - self.on_temperature_changed) - self.temperature_line_edit.setValidator(TemperatureValidator(self)) - self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER) - - self.initial_prompt_text_edit = QPlainTextEdit( - transcription_options.initial_prompt, self) - self.initial_prompt_text_edit.textChanged.connect( - self.on_initial_prompt_changed) - self.initial_prompt_text_edit.setEnabled( - transcription_options.model.model_type == ModelType.WHISPER) - - layout.addRow(_('Temperature:'), self.temperature_line_edit) - layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit) - layout.addWidget(button_box) - - self.setLayout(layout) - self.setFixedSize(self.sizeHint()) - - def on_temperature_changed(self, text: str): - try: - temperatures = [float(temp.strip()) for temp in text.split(',')] - self.transcription_options.temperature = tuple(temperatures) - self.transcription_options_changed.emit(self.transcription_options) - except ValueError: - pass - - def on_initial_prompt_changed(self): - self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText() - self.transcription_options_changed.emit(self.transcription_options) - - -class TemperatureValidator(QValidator): - def __init__(self, parent: Optional[QObject] = ...) -> None: - super().__init__(parent) - - def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]: - try: - temp_strings = [temp.strip() for temp in text.split(',')] - if temp_strings[-1] == '': - return QValidator.State.Intermediate, text, cursor_position - _ = [float(temp) for temp in temp_strings] - return QValidator.State.Acceptable, text, cursor_position - except ValueError: - return QValidator.State.Invalid, text, cursor_position diff --git a/buzz/model_loader.py b/buzz/model_loader.py index b3538899..d4597165 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -19,7 +19,7 @@ from platformdirs import user_cache_dir from tqdm.auto import tqdm -class WhisperModelSize(enum.Enum): +class WhisperModelSize(str, enum.Enum): TINY = 'tiny' BASE = 'base' SMALL = 'small' diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 524b0144..28614573 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -25,14 +25,18 @@ class Settings: FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings' FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats' + DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name' + SHORTCUTS = 'shortcuts' def set_value(self, key: Key, value: typing.Any) -> None: self.settings.setValue(key.value, value) - def value(self, key: Key, default_value: typing.Any, value_type: typing.Optional[type] = None) -> typing.Any: + def value(self, key: Key, default_value: typing.Any, + value_type: typing.Optional[type] = None) -> typing.Any: return self.settings.value(key.value, default_value, - value_type if value_type is not None else type(default_value)) + value_type if value_type is not None else type( + default_value)) def clear(self): self.settings.clear() diff --git a/buzz/transcriber.py b/buzz/transcriber.py index e22a64aa..48316a38 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -66,13 +66,15 @@ class TranscriptionOptions: word_level_timings: bool = False temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE initial_prompt: str = '' - openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS)) + openai_access_token: str = field(default='', + metadata=config(exclude=Exclude.ALWAYS)) @dataclass() class FileTranscriptionOptions: file_paths: List[str] output_formats: Set['OutputFormat'] = field(default_factory=set) + default_output_file_name: str = '' @dataclass_json @@ -127,12 +129,11 @@ class FileTranscriber(QObject): self.completed.emit(segments) for output_format in self.transcription_task.file_transcription_options.output_formats: - default_path = get_default_output_file_path( - task=self.transcription_task.transcription_options.task, - input_file_path=self.transcription_task.file_path, - output_format=output_format) + default_path = get_default_output_file_path(task=self.transcription_task, + output_format=output_format) - write_output(path=default_path, segments=segments, output_format=output_format) + write_output(path=default_path, segments=segments, + output_format=output_format) @abstractmethod def transcribe(self) -> List[Segment]: @@ -172,17 +173,22 @@ class WhisperCppFileTranscriber(FileTranscriber): logging.debug( 'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, ' 'word level timings = %s', - self.file_path, self.language, self.task, model_path, self.word_level_timings) + self.file_path, self.language, self.task, model_path, + self.word_level_timings) audio = whisper.audio.load_audio(self.file_path) self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE - whisper_params = whisper_cpp_params(language=self.language if self.language is not None else '', task=self.task, - word_level_timings=self.word_level_timings) - whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(id(self.state)) - whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback) + whisper_params = whisper_cpp_params( + language=self.language if self.language is not None else '', task=self.task, + word_level_timings=self.word_level_timings) + whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p( + id(self.state)) + whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback( + self.encoder_begin_callback) whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state)) - whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(self.new_segment_callback) + whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback( + self.new_segment_callback) model = WhisperCpp(model=model_path) result = model.transcribe(audio=self.file_path, params=whisper_params) @@ -199,13 +205,15 @@ class WhisperCppFileTranscriber(FileTranscriber): # t1 seems to sometimes be larger than the duration when the # audio ends in silence. Trim to fix the displayed progress. progress = min(t1 * 10, self.duration_audio_ms) - state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value + state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, + ctypes.py_object).value if state.running: self.progress.emit((progress, self.duration_audio_ms)) @staticmethod def encoder_begin_callback(_ctx, _state, user_data): - state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value + state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, + ctypes.py_object).value return state.running == 1 def stop(self): @@ -219,8 +227,10 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber): self.task = task.transcription_options.task def transcribe(self) -> List[Segment]: - logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path, - self.task) + logging.debug( + 'Starting OpenAI Whisper API file transcription, file path = %s, task = %s', + self.file_path, + self.task) wav_file = tempfile.mktemp() + '.wav' ( @@ -235,14 +245,18 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber): language = self.transcription_task.transcription_options.language response_format = "verbose_json" if self.transcription_task.transcription_options.task == Task.TRANSLATE: - transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format, + transcript = openai.Audio.translate("whisper-1", audio_file, + response_format=response_format, language=language) else: - transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format, + transcript = openai.Audio.transcribe("whisper-1", audio_file, + response_format=response_format, language=language) - segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in - transcript["segments"]] + segments = [ + Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for + segment in + transcript["segments"]] return segments def stop(self): @@ -273,7 +287,8 @@ class WhisperFileTranscriber(FileTranscriber): recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False) self.current_process = multiprocessing.Process(target=self.transcribe_whisper, - args=(send_pipe, self.transcription_task)) + args=(send_pipe, + self.transcription_task)) if not self.stopped: self.current_process.start() self.started_process = True @@ -291,7 +306,8 @@ class WhisperFileTranscriber(FileTranscriber): logging.debug( 'whisper process completed with code = %s, time taken = %s, number of segments = %s', - self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments)) + self.current_process.exitcode, datetime.datetime.now() - time_started, + len(self.segments)) if self.current_process.exitcode != 0: raise Exception('Unknown error') @@ -299,7 +315,8 @@ class WhisperFileTranscriber(FileTranscriber): return self.segments @classmethod - def transcribe_whisper(cls, stderr_conn: Connection, task: FileTranscriptionTask) -> None: + def transcribe_whisper(cls, stderr_conn: Connection, + task: FileTranscriptionTask) -> None: with pipe_stderr(stderr_conn): if task.transcription_options.model.model_type == ModelType.HUGGING_FACE: segments = cls.transcribe_hugging_face(task) @@ -308,7 +325,8 @@ class WhisperFileTranscriber(FileTranscriber): elif task.transcription_options.model.model_type == ModelType.WHISPER: segments = cls.transcribe_openai_whisper(task) else: - raise Exception(f"Invalid model type: {task.transcription_options.model.model_type}") + raise Exception( + f"Invalid model type: {task.transcription_options.model.model_type}") segments_json = json.dumps( segments, ensure_ascii=True, default=vars) @@ -321,7 +339,8 @@ class WhisperFileTranscriber(FileTranscriber): model = transformers_whisper.load_model(task.model_path) language = task.transcription_options.language if task.transcription_options.language is not None else 'en' result = model.transcribe(audio=task.file_path, language=language, - task=task.transcription_options.task.value, verbose=False) + task=task.transcription_options.task.value, + verbose=False) return [ Segment( start=int(segment.get('start') * 1000), @@ -368,7 +387,8 @@ class WhisperFileTranscriber(FileTranscriber): stable_whisper.modify_model(model) result = model.transcribe( audio=task.file_path, language=task.transcription_options.language, - task=task.transcription_options.task.value, temperature=task.transcription_options.temperature, + task=task.transcription_options.task.value, + temperature=task.transcription_options.temperature, initial_prompt=task.transcription_options.initial_prompt, pbar=True) segments = stable_whisper.group_word_timestamps(result) return [Segment( @@ -423,7 +443,8 @@ class WhisperFileTranscriber(FileTranscriber): def write_output(path: str, segments: List[Segment], output_format: OutputFormat): logging.debug( - 'Writing transcription output, path = %s, output format = %s, number of segments = %s', path, output_format, + 'Writing transcription output, path = %s, output format = %s, number of segments = %s', + path, output_format, len(segments)) with open(path, 'w', encoding='utf-8') as file: @@ -473,8 +494,22 @@ SUPPORTED_OUTPUT_FORMATS = 'Audio files (*.mp3 *.wav *.m4a *.ogg);;\ Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)' -def get_default_output_file_path(task: Task, input_file_path: str, output_format: OutputFormat): - return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).{output_format.value}' +def get_default_output_file_path(task: FileTranscriptionTask, + output_format: OutputFormat): + input_file_name = os.path.splitext(task.file_path)[0] + date_time_now = datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S') + return (task.file_transcription_options.default_output_file_name + .replace('{{ input_file_name }}', input_file_name) + .replace('{{ task }}', task.transcription_options.task.value) + .replace('{{ language }}', task.transcription_options.language or '') + .replace('{{ model_type }}', + task.transcription_options.model.model_type.value) + .replace('{{ model_size }}', + task.transcription_options.model.whisper_model_size.value if + task.transcription_options.model.whisper_model_size is not None else + '') + .replace('{{ date_time }}', date_time_now) + + f".{output_format.value}") def whisper_cpp_params( diff --git a/buzz/widgets/menu_bar.py b/buzz/widgets/menu_bar.py index 46dd30f5..65ae2f91 100644 --- a/buzz/widgets/menu_bar.py +++ b/buzz/widgets/menu_bar.py @@ -1,3 +1,4 @@ +import webbrowser from typing import Dict from PyQt6.QtCore import pyqtSignal @@ -15,11 +16,14 @@ class MenuBar(QMenuBar): import_action_triggered = pyqtSignal() shortcuts_changed = pyqtSignal(dict) openai_api_key_changed = pyqtSignal(str) + default_export_file_name_changed = pyqtSignal(str) - def __init__(self, shortcuts: Dict[str, str], parent: QWidget): + def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str, + parent: QWidget): super().__init__(parent) self.shortcuts = shortcuts + self.default_export_file_name = default_export_file_name self.import_action = QAction(_("Import Media File..."), self) self.import_action.triggered.connect( @@ -31,6 +35,9 @@ class MenuBar(QMenuBar): self.preferences_action = QAction(_("Preferences..."), self) self.preferences_action.triggered.connect(self.on_preferences_action_triggered) + help_action = QAction(f'{_("Help")}', self) + help_action.triggered.connect(self.on_help_action_triggered) + self.set_shortcuts(shortcuts) file_menu = self.addMenu(_("File")) @@ -38,6 +45,7 @@ class MenuBar(QMenuBar): help_menu = self.addMenu(_("Help")) help_menu.addAction(about_action) + help_menu.addAction(help_action) help_menu.addAction(self.preferences_action) def on_import_action_triggered(self): @@ -48,11 +56,18 @@ class MenuBar(QMenuBar): about_dialog.open() def on_preferences_action_triggered(self): - preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, parent=self) + preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, + default_export_file_name=self.default_export_file_name, + parent=self) preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed) preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed) + preferences_dialog.default_export_file_name_changed.connect( + self.default_export_file_name_changed) preferences_dialog.open() + def on_help_action_triggered(self): + webbrowser.open('https://chidiwilliams.github.io/buzz/docs') + def set_shortcuts(self, shortcuts: Dict[str, str]): self.shortcuts = shortcuts diff --git a/buzz/widgets/preferences_dialog/general_preferences_widget.py b/buzz/widgets/preferences_dialog/general_preferences_widget.py index 121dd5b8..25e07afa 100644 --- a/buzz/widgets/preferences_dialog/general_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/general_preferences_widget.py @@ -3,33 +3,45 @@ from typing import Optional import openai from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool -from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox +from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox, QLineEdit from openai.error import AuthenticationError from buzz.store.keyring_store import KeyringStore +from buzz.widgets.line_edit import LineEdit from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit class GeneralPreferencesWidget(QWidget): openai_api_key_changed = pyqtSignal(str) + default_export_file_name_changed = pyqtSignal(str) - def __init__(self, keyring_store=KeyringStore(), parent: Optional[QWidget] = None): + def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(), + parent: Optional[QWidget] = None): super().__init__(parent) - self.openai_api_key = keyring_store.get_password(KeyringStore.Key.OPENAI_API_KEY) + self.openai_api_key = keyring_store.get_password( + KeyringStore.Key.OPENAI_API_KEY) layout = QFormLayout(self) self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self) - self.openai_api_key_line_edit.key_changed.connect(self.on_openai_api_key_changed) + self.openai_api_key_line_edit.key_changed.connect( + self.on_openai_api_key_changed) self.test_openai_api_key_button = QPushButton('Test') - self.test_openai_api_key_button.clicked.connect(self.on_click_test_openai_api_key_button) + self.test_openai_api_key_button.clicked.connect( + self.on_click_test_openai_api_key_button) self.update_test_openai_api_key_button() layout.addRow('OpenAI API Key', self.openai_api_key_line_edit) layout.addRow('', self.test_openai_api_key_button) + default_export_file_name_line_edit = LineEdit(default_export_file_name, self) + default_export_file_name_line_edit.textChanged.connect( + self.default_export_file_name_changed) + default_export_file_name_line_edit.setMinimumWidth(200) + layout.addRow('Default export file name', default_export_file_name_line_edit) + self.setLayout(layout) def update_test_openai_api_key_button(self): diff --git a/buzz/widgets/preferences_dialog/preferences_dialog.py b/buzz/widgets/preferences_dialog/preferences_dialog.py index be5e5cbf..fd68b7c7 100644 --- a/buzz/widgets/preferences_dialog/preferences_dialog.py +++ b/buzz/widgets/preferences_dialog/preferences_dialog.py @@ -15,8 +15,9 @@ from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import class PreferencesDialog(QDialog): shortcuts_changed = pyqtSignal(dict) openai_api_key_changed = pyqtSignal(str) + default_export_file_name_changed = pyqtSignal(str) - def __init__(self, shortcuts: Dict[str, str], + def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str, parent: Optional[QWidget] = None) -> None: super().__init__(parent) @@ -25,8 +26,11 @@ class PreferencesDialog(QDialog): layout = QVBoxLayout(self) tab_widget = QTabWidget(self) - general_tab_widget = GeneralPreferencesWidget(parent=self) + general_tab_widget = GeneralPreferencesWidget( + default_export_file_name=default_export_file_name, parent=self) general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed) + general_tab_widget.default_export_file_name_changed.connect( + self.default_export_file_name_changed) tab_widget.addTab(general_tab_widget, _('General')) models_tab_widget = ModelsPreferencesWidget(parent=self) diff --git a/buzz/widgets/transcriber/__init__.py b/buzz/widgets/transcriber/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/buzz/widgets/transcriber/advanced_settings_button.py b/buzz/widgets/transcriber/advanced_settings_button.py new file mode 100644 index 00000000..84100559 --- /dev/null +++ b/buzz/widgets/transcriber/advanced_settings_button.py @@ -0,0 +1,8 @@ +from typing import Optional + +from PyQt6.QtWidgets import QPushButton, QWidget + + +class AdvancedSettingsButton(QPushButton): + def __init__(self, parent: Optional[QWidget]) -> None: + super().__init__('Advanced...', parent) diff --git a/buzz/widgets/transcriber/advanced_settings_dialog.py b/buzz/widgets/transcriber/advanced_settings_dialog.py new file mode 100644 index 00000000..172760cc --- /dev/null +++ b/buzz/widgets/transcriber/advanced_settings_dialog.py @@ -0,0 +1,64 @@ +from PyQt6.QtCore import pyqtSignal +from PyQt6.QtWidgets import QDialog, QWidget, QDialogButtonBox, QFormLayout, \ + QPlainTextEdit + +from buzz.widgets.transcriber.temperature_validator import TemperatureValidator +from buzz.locale import _ +from buzz.model_loader import ModelType +from buzz.transcriber import TranscriptionOptions +from buzz.widgets.line_edit import LineEdit + + +class AdvancedSettingsDialog(QDialog): + transcription_options: TranscriptionOptions + transcription_options_changed = pyqtSignal(TranscriptionOptions) + + def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None): + super().__init__(parent) + + self.transcription_options = transcription_options + + self.setWindowTitle(_('Advanced Settings')) + + button_box = QDialogButtonBox(QDialogButtonBox.StandardButton( + QDialogButtonBox.StandardButton.Ok), self) + button_box.accepted.connect(self.accept) + + layout = QFormLayout(self) + + default_temperature_text = ', '.join( + [str(temp) for temp in transcription_options.temperature]) + self.temperature_line_edit = LineEdit(default_temperature_text, self) + self.temperature_line_edit.setPlaceholderText( + _('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')) + self.temperature_line_edit.setMinimumWidth(170) + self.temperature_line_edit.textChanged.connect( + self.on_temperature_changed) + self.temperature_line_edit.setValidator(TemperatureValidator(self)) + self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER) + + self.initial_prompt_text_edit = QPlainTextEdit( + transcription_options.initial_prompt, self) + self.initial_prompt_text_edit.textChanged.connect( + self.on_initial_prompt_changed) + self.initial_prompt_text_edit.setEnabled( + transcription_options.model.model_type == ModelType.WHISPER) + + layout.addRow(_('Temperature:'), self.temperature_line_edit) + layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit) + layout.addWidget(button_box) + + self.setLayout(layout) + self.setFixedSize(self.sizeHint()) + + def on_temperature_changed(self, text: str): + try: + temperatures = [float(temp.strip()) for temp in text.split(',')] + self.transcription_options.temperature = tuple(temperatures) + self.transcription_options_changed.emit(self.transcription_options) + except ValueError: + pass + + def on_initial_prompt_changed(self): + self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText() + self.transcription_options_changed.emit(self.transcription_options) diff --git a/buzz/widgets/transcriber/file_transcriber_widget.py b/buzz/widgets/transcriber/file_transcriber_widget.py new file mode 100644 index 00000000..ebcd6629 --- /dev/null +++ b/buzz/widgets/transcriber/file_transcriber_widget.py @@ -0,0 +1,204 @@ +from typing import Optional, List, Tuple + +from PyQt6 import QtGui +from PyQt6.QtCore import pyqtSignal, Qt, QThreadPool +from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout, \ + QPushButton + +from buzz.dialogs import show_model_download_error_dialog +from buzz.locale import _ +from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType +from buzz.paths import file_paths_as_title +from buzz.settings.settings import Settings +from buzz.store.keyring_store import KeyringStore +from buzz.transcriber import FileTranscriptionOptions, TranscriptionOptions, Task, \ + DEFAULT_WHISPER_TEMPERATURE, OutputFormat +from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog +from buzz.widgets.transcriber.transcription_options_group_box import \ + TranscriptionOptionsGroupBox + + +class FileTranscriberWidget(QWidget): + model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None + model_loader: Optional[ModelDownloader] = None + file_transcription_options: FileTranscriptionOptions + transcription_options: TranscriptionOptions + is_transcribing = False + # (TranscriptionOptions, FileTranscriptionOptions, str) + triggered = pyqtSignal(tuple) + openai_access_token_changed = pyqtSignal(str) + settings = Settings() + + def __init__(self, file_paths: List[str], + default_output_file_name: str, + parent: Optional[QWidget] = None, + flags: Qt.WindowType = Qt.WindowType.Widget) -> None: + super().__init__(parent, flags) + + self.setWindowTitle(file_paths_as_title(file_paths)) + + openai_access_token = KeyringStore().get_password( + KeyringStore.Key.OPENAI_API_KEY) + + self.file_paths = file_paths + default_language = self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='') + self.transcription_options = TranscriptionOptions( + openai_access_token=openai_access_token, + model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL, + default_value=TranscriptionModel()), + task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK, + default_value=Task.TRANSCRIBE), + language=default_language if default_language != '' else None, + initial_prompt=self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''), + temperature=self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, + default_value=DEFAULT_WHISPER_TEMPERATURE), + word_level_timings=self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, + default_value=False)) + default_export_format_states: List[str] = self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, + default_value=[]) + self.file_transcription_options = FileTranscriptionOptions( + file_paths=self.file_paths, + output_formats=set([OutputFormat(output_format) for output_format in + default_export_format_states]), + default_output_file_name=default_output_file_name) + + layout = QVBoxLayout(self) + + transcription_options_group_box = TranscriptionOptionsGroupBox( + default_transcription_options=self.transcription_options, parent=self) + transcription_options_group_box.transcription_options_changed.connect( + self.on_transcription_options_changed) + + self.word_level_timings_checkbox = QCheckBox(_('Word-level timings')) + self.word_level_timings_checkbox.setChecked( + self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, + default_value=False)) + self.word_level_timings_checkbox.stateChanged.connect( + self.on_word_level_timings_changed) + + file_transcription_layout = QFormLayout() + file_transcription_layout.addRow('', self.word_level_timings_checkbox) + + export_format_layout = QHBoxLayout() + for output_format in OutputFormat: + export_format_checkbox = QCheckBox(f'{output_format.value.upper()}', + parent=self) + export_format_checkbox.setChecked( + output_format in self.file_transcription_options.output_formats) + export_format_checkbox.stateChanged.connect( + self.get_on_checkbox_state_changed_callback(output_format)) + export_format_layout.addWidget(export_format_checkbox) + + file_transcription_layout.addRow('Export:', export_format_layout) + + self.run_button = QPushButton(_('Run'), self) + self.run_button.setDefault(True) + self.run_button.clicked.connect(self.on_click_run) + + layout.addWidget(transcription_options_group_box) + layout.addLayout(file_transcription_layout) + layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight) + + self.setLayout(layout) + self.setFixedSize(self.sizeHint()) + + def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat): + def on_checkbox_state_changed(state: int): + if state == Qt.CheckState.Checked.value: + self.file_transcription_options.output_formats.add(output_format) + elif state == Qt.CheckState.Unchecked.value: + self.file_transcription_options.output_formats.remove(output_format) + + return on_checkbox_state_changed + + def on_transcription_options_changed(self, + transcription_options: TranscriptionOptions): + self.transcription_options = transcription_options + self.word_level_timings_checkbox.setDisabled( + self.transcription_options.model.model_type == ModelType.HUGGING_FACE or + self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API) + if self.transcription_options.openai_access_token != '': + self.openai_access_token_changed.emit( + self.transcription_options.openai_access_token) + + def on_click_run(self): + self.run_button.setDisabled(True) + + model_path = self.transcription_options.model.get_local_model_path() + if model_path is not None: + self.on_model_loaded(model_path) + return + + self.model_loader = ModelDownloader(model=self.transcription_options.model) + self.model_loader.signals.progress.connect(self.on_download_model_progress) + self.model_loader.signals.error.connect(self.on_download_model_error) + self.model_loader.signals.finished.connect(self.on_model_loaded) + QThreadPool().globalInstance().start(self.model_loader) + + def on_model_loaded(self, model_path: str): + self.reset_transcriber_controls() + + self.triggered.emit((self.transcription_options, + self.file_transcription_options, model_path)) + self.close() + + def on_download_model_progress(self, progress: Tuple[float, float]): + (current_size, total_size) = progress + + if self.model_download_progress_dialog is None: + self.model_download_progress_dialog = ModelDownloadProgressDialog( + model_type=self.transcription_options.model.model_type, parent=self) + self.model_download_progress_dialog.canceled.connect( + self.on_cancel_model_progress_dialog) + + if self.model_download_progress_dialog is not None: + self.model_download_progress_dialog.set_value( + fraction_completed=current_size / total_size) + + def on_download_model_error(self, error: str): + self.reset_model_download() + show_model_download_error_dialog(self, error) + self.reset_transcriber_controls() + + def reset_transcriber_controls(self): + self.run_button.setDisabled(False) + + def on_cancel_model_progress_dialog(self): + if self.model_loader is not None: + self.model_loader.cancel() + self.reset_model_download() + + def reset_model_download(self): + if self.model_download_progress_dialog is not None: + self.model_download_progress_dialog.close() + self.model_download_progress_dialog = None + + def on_word_level_timings_changed(self, value: int): + self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value + + def closeEvent(self, event: QtGui.QCloseEvent) -> None: + if self.model_loader is not None: + self.model_loader.cancel() + + self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, + self.transcription_options.language) + self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, + self.transcription_options.task) + self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, + self.transcription_options.temperature) + self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, + self.transcription_options.initial_prompt) + self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL, + self.transcription_options.model) + self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, + value=self.transcription_options.word_level_timings) + self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, + value=[export_format.value for export_format in + self.file_transcription_options.output_formats]) + + super().closeEvent(event) diff --git a/buzz/widgets/transcriber/hugging_face_search_line_edit.py b/buzz/widgets/transcriber/hugging_face_search_line_edit.py new file mode 100644 index 00000000..4c98a516 --- /dev/null +++ b/buzz/widgets/transcriber/hugging_face_search_line_edit.py @@ -0,0 +1,134 @@ +import json +import logging +from typing import Optional + +from PyQt6.QtCore import pyqtSignal, QTimer, Qt, QMetaObject, QUrl, QUrlQuery, QPoint, \ + QObject, QEvent +from PyQt6.QtGui import QKeyEvent +from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply +from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem + +from buzz.widgets.line_edit import LineEdit + + +# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py +class HuggingFaceSearchLineEdit(LineEdit): + model_selected = pyqtSignal(str) + popup: QListWidget + + def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None, + parent: Optional[QWidget] = None): + super().__init__('', parent) + + self.setMinimumWidth(150) + self.setPlaceholderText('openai/whisper-tiny') + + self.timer = QTimer(self) + self.timer.setSingleShot(True) + self.timer.setInterval(250) + self.timer.timeout.connect(self.fetch_models) + + # Restart debounce timer each time editor text changes + self.textEdited.connect(self.timer.start) + self.textEdited.connect(self.on_text_edited) + + if network_access_manager is None: + network_access_manager = QNetworkAccessManager(self) + + self.network_manager = network_access_manager + self.network_manager.finished.connect(self.on_request_response) + + self.popup = QListWidget() + self.popup.setWindowFlags(Qt.WindowType.Popup) + self.popup.setFocusProxy(self) + self.popup.setMouseTracking(True) + self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) + self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.popup.installEventFilter(self) + self.popup.itemClicked.connect(self.on_select_item) + + def on_text_edited(self, text: str): + self.model_selected.emit(text) + + def on_select_item(self): + self.popup.hide() + self.setFocus() + + item = self.popup.currentItem() + self.setText(item.text()) + QMetaObject.invokeMethod(self, 'returnPressed') + self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole)) + + def fetch_models(self): + text = self.text() + if len(text) < 3: + return + + url = QUrl("https://huggingface.co/api/models") + + query = QUrlQuery() + query.addQueryItem("filter", "whisper") + query.addQueryItem("search", text) + + url.setQuery(query) + + return self.network_manager.get(QNetworkRequest(url)) + + def on_popup_selected(self): + self.timer.stop() + + def on_request_response(self, network_reply: QNetworkReply): + if network_reply.error() != QNetworkReply.NetworkError.NoError: + logging.debug('Error fetching Hugging Face models: %s', network_reply.error()) + return + + models = json.loads(network_reply.readAll().data()) + + self.popup.setUpdatesEnabled(False) + self.popup.clear() + + for model in models: + model_id = model.get('id') + + item = QListWidgetItem(self.popup) + item.setText(model_id) + item.setData(Qt.ItemDataRole.UserRole, model_id) + + self.popup.setCurrentItem(self.popup.item(0)) + self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20) + self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll + self.popup.setUpdatesEnabled(True) + self.popup.move(self.mapToGlobal(QPoint(0, self.height()))) + self.popup.setFocus() + self.popup.show() + + def eventFilter(self, target: QObject, event: QEvent): + if hasattr(self, 'popup') is False or target != self.popup: + return False + + if event.type() == QEvent.Type.MouseButtonPress: + self.popup.hide() + self.setFocus() + return True + + if isinstance(event, QKeyEvent): + key = event.key() + if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]: + if self.popup.currentItem() is not None: + self.on_select_item() + return True + + if key == Qt.Key.Key_Escape: + self.setFocus() + self.popup.hide() + return True + + if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp, + Qt.Key.Key_PageDown]: + return False + + self.setFocus() + self.event(event) + self.popup.hide() + + return False diff --git a/buzz/widgets/transcriber/languages_combo_box.py b/buzz/widgets/transcriber/languages_combo_box.py new file mode 100644 index 00000000..786a482d --- /dev/null +++ b/buzz/widgets/transcriber/languages_combo_box.py @@ -0,0 +1,31 @@ +from typing import Optional + +from PyQt6.QtCore import pyqtSignal +from PyQt6.QtWidgets import QComboBox, QWidget + +from buzz.locale import _ +from buzz.transcriber import LANGUAGES + + +class LanguagesComboBox(QComboBox): + """LanguagesComboBox displays a list of languages available to use with Whisper""" + # language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language" + languageChanged = pyqtSignal(str) + + def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None: + super().__init__(parent) + + whisper_languages = sorted( + [(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1]) + self.languages = [('', _('Detect Language'))] + whisper_languages + + self.addItems([lang[1] for lang in self.languages]) + self.currentIndexChanged.connect(self.on_index_changed) + + default_language_key = default_language if default_language != '' else None + for i, lang in enumerate(self.languages): + if lang[0] == default_language_key: + self.setCurrentIndex(i) + + def on_index_changed(self, index: int): + self.languageChanged.emit(self.languages[index][0]) diff --git a/buzz/widgets/transcriber/tasks_combo_box.py b/buzz/widgets/transcriber/tasks_combo_box.py new file mode 100644 index 00000000..2b2c02cc --- /dev/null +++ b/buzz/widgets/transcriber/tasks_combo_box.py @@ -0,0 +1,21 @@ +from typing import Optional + +from PyQt6.QtCore import pyqtSignal +from PyQt6.QtWidgets import QComboBox, QWidget + +from buzz.transcriber import Task + + +class TasksComboBox(QComboBox): + """TasksComboBox displays a list of tasks available to use with Whisper""" + taskChanged = pyqtSignal(Task) + + def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None: + super().__init__(parent, *args) + self.tasks = [i for i in Task] + self.addItems(map(lambda task: task.value.title(), self.tasks)) + self.currentIndexChanged.connect(self.on_index_changed) + self.setCurrentText(default_task.value.title()) + + def on_index_changed(self, index: int): + self.taskChanged.emit(self.tasks[index]) diff --git a/buzz/widgets/transcriber/temperature_validator.py b/buzz/widgets/transcriber/temperature_validator.py new file mode 100644 index 00000000..29986a62 --- /dev/null +++ b/buzz/widgets/transcriber/temperature_validator.py @@ -0,0 +1,19 @@ +from typing import Optional, Tuple + +from PyQt6.QtCore import QObject +from PyQt6.QtGui import QValidator + + +class TemperatureValidator(QValidator): + def __init__(self, parent: Optional[QObject] = ...) -> None: + super().__init__(parent) + + def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]: + try: + temp_strings = [temp.strip() for temp in text.split(',')] + if temp_strings[-1] == '': + return QValidator.State.Intermediate, text, cursor_position + _ = [float(temp) for temp in temp_strings] + return QValidator.State.Acceptable, text, cursor_position + except ValueError: + return QValidator.State.Invalid, text, cursor_position diff --git a/buzz/widgets/transcriber/transcription_options_group_box.py b/buzz/widgets/transcriber/transcription_options_group_box.py new file mode 100644 index 00000000..3ec1f676 --- /dev/null +++ b/buzz/widgets/transcriber/transcription_options_group_box.py @@ -0,0 +1,140 @@ +from typing import Optional, List, Tuple + +from PyQt6.QtCore import pyqtSignal +from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox + +from buzz.locale import _ +from buzz.model_loader import ModelType, WhisperModelSize +from buzz.transcriber import TranscriptionOptions, Task +from buzz.widgets.model_type_combo_box import ModelTypeComboBox +from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit +from buzz.widgets.transcriber.advanced_settings_button import AdvancedSettingsButton +from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog +from buzz.widgets.transcriber.hugging_face_search_line_edit import \ + HuggingFaceSearchLineEdit +from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox +from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox + + +class TranscriptionOptionsGroupBox(QGroupBox): + transcription_options: TranscriptionOptions + transcription_options_changed = pyqtSignal(TranscriptionOptions) + + def __init__( + self, + default_transcription_options: TranscriptionOptions = TranscriptionOptions(), + model_types: Optional[List[ModelType]] = None, + parent: Optional[QWidget] = None): + super().__init__(title='', parent=parent) + self.transcription_options = default_transcription_options + + self.form_layout = QFormLayout(self) + + self.tasks_combo_box = TasksComboBox( + default_task=self.transcription_options.task, + parent=self) + self.tasks_combo_box.taskChanged.connect(self.on_task_changed) + + self.languages_combo_box = LanguagesComboBox( + default_language=self.transcription_options.language, + parent=self) + self.languages_combo_box.languageChanged.connect( + self.on_language_changed) + + self.advanced_settings_button = AdvancedSettingsButton(self) + self.advanced_settings_button.clicked.connect( + self.open_advanced_settings) + + self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit() + self.hugging_face_search_line_edit.model_selected.connect( + self.on_hugging_face_model_changed) + + self.model_type_combo_box = ModelTypeComboBox(model_types=model_types, + default_model=default_transcription_options.model.model_type, + parent=self) + self.model_type_combo_box.changed.connect(self.on_model_type_changed) + + self.whisper_model_size_combo_box = QComboBox(self) + self.whisper_model_size_combo_box.addItems( + [size.value.title() for size in WhisperModelSize]) + if default_transcription_options.model.whisper_model_size is not None: + self.whisper_model_size_combo_box.setCurrentText( + default_transcription_options.model.whisper_model_size.value.title()) + self.whisper_model_size_combo_box.currentTextChanged.connect( + self.on_whisper_model_size_changed) + + self.openai_access_token_edit = OpenAIAPIKeyLineEdit( + key=default_transcription_options.openai_access_token, + parent=self) + self.openai_access_token_edit.key_changed.connect( + self.on_openai_access_token_edit_changed) + + self.form_layout.addRow(_('Model:'), self.model_type_combo_box) + self.form_layout.addRow('', self.whisper_model_size_combo_box) + self.form_layout.addRow('', self.hugging_face_search_line_edit) + self.form_layout.addRow('Access Token:', self.openai_access_token_edit) + self.form_layout.addRow(_('Task:'), self.tasks_combo_box) + self.form_layout.addRow(_('Language:'), self.languages_combo_box) + + self.reset_visible_rows() + + self.form_layout.addRow('', self.advanced_settings_button) + + self.setLayout(self.form_layout) + + def on_openai_access_token_edit_changed(self, access_token: str): + self.transcription_options.openai_access_token = access_token + self.transcription_options_changed.emit(self.transcription_options) + + def on_language_changed(self, language: str): + self.transcription_options.language = language + self.transcription_options_changed.emit(self.transcription_options) + + def on_task_changed(self, task: Task): + self.transcription_options.task = task + self.transcription_options_changed.emit(self.transcription_options) + + def on_temperature_changed(self, temperature: Tuple[float, ...]): + self.transcription_options.temperature = temperature + self.transcription_options_changed.emit(self.transcription_options) + + def on_initial_prompt_changed(self, initial_prompt: str): + self.transcription_options.initial_prompt = initial_prompt + self.transcription_options_changed.emit(self.transcription_options) + + def open_advanced_settings(self): + dialog = AdvancedSettingsDialog( + transcription_options=self.transcription_options, parent=self) + dialog.transcription_options_changed.connect( + self.on_transcription_options_changed) + dialog.exec() + + def on_transcription_options_changed(self, + transcription_options: TranscriptionOptions): + self.transcription_options = transcription_options + self.transcription_options_changed.emit(transcription_options) + + def reset_visible_rows(self): + model_type = self.transcription_options.model.model_type + self.form_layout.setRowVisible(self.hugging_face_search_line_edit, + model_type == ModelType.HUGGING_FACE) + self.form_layout.setRowVisible(self.whisper_model_size_combo_box, + (model_type == ModelType.WHISPER) or ( + model_type == ModelType.WHISPER_CPP) or ( + model_type == ModelType.FASTER_WHISPER)) + self.form_layout.setRowVisible(self.openai_access_token_edit, + model_type == ModelType.OPEN_AI_WHISPER_API) + + def on_model_type_changed(self, model_type: ModelType): + self.transcription_options.model.model_type = model_type + self.reset_visible_rows() + self.transcription_options_changed.emit(self.transcription_options) + + def on_whisper_model_size_changed(self, text: str): + model_size = WhisperModelSize(text.lower()) + self.transcription_options.model.whisper_model_size = model_size + self.transcription_options_changed.emit(self.transcription_options) + + def on_hugging_face_model_changed(self, model: str): + self.transcription_options.model.hugging_face_model_id = model + self.transcription_options_changed.emit(self.transcription_options) diff --git a/buzz/widgets/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer_widget.py index 42f46887..41cae3ac 100644 --- a/buzz/widgets/transcription_viewer_widget.py +++ b/buzz/widgets/transcription_viewer_widget.py @@ -134,10 +134,8 @@ class TranscriptionViewerWidget(QWidget): def on_menu_triggered(self, action: QAction): output_format = OutputFormat[action.text()] - default_path = get_default_output_file_path( - task=self.transcription_task.transcription_options.task, - input_file_path=self.transcription_task.file_path, - output_format=output_format) + default_path = get_default_output_file_path(task=self.transcription_task, + output_format=output_format) (output_file_path, nil) = QFileDialog.getSaveFileName(self, _('Save File'), default_path, diff --git a/tests/gui_test.py b/tests/gui_test.py index b6330e1c..92362cec 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -14,16 +14,21 @@ from pytestqt.qtbot import QtBot from buzz.__version__ import VERSION from buzz.cache import TasksCache -from buzz.gui import (AdvancedSettingsDialog, AudioDevicesComboBox, FileTranscriberWidget, - LanguagesComboBox, MainWindow, - RecordingTranscriberWidget, - TemperatureValidator, HuggingFaceSearchLineEdit, - TranscriptionOptionsGroupBox) +from buzz.gui import (AudioDevicesComboBox, MainWindow, + RecordingTranscriberWidget) +from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog +from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget +from buzz.widgets.transcriber.hugging_face_search_line_edit import \ + HuggingFaceSearchLineEdit +from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox +from buzz.widgets.transcriber.temperature_validator import TemperatureValidator from buzz.widgets.about_dialog import AboutDialog from buzz.model_loader import ModelType from buzz.settings.settings import Settings from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, TranscriptionOptions) +from buzz.widgets.transcriber.transcription_options_group_box import \ + TranscriptionOptionsGroupBox from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget from tests.mock_sounddevice import MockInputStream, mock_query_devices from .mock_qt import MockNetworkAccessManager, MockNetworkReply @@ -278,32 +283,6 @@ def clear_settings(): settings.clear() -class TestFileTranscriberWidget: - def test_should_set_window_title(self, qtbot: QtBot): - widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], parent=None) - qtbot.add_widget(widget) - assert widget.windowTitle() == 'whisper-french.mp3' - - def test_should_emit_triggered_event(self, qtbot: QtBot): - widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], parent=None) - qtbot.add_widget(widget) - - mock_triggered = Mock() - widget.triggered.connect(mock_triggered) - - with qtbot.wait_signal(widget.triggered, timeout=30 * 1000): - qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton) - - transcription_options, file_transcription_options, model_path = mock_triggered.call_args[ - 0][0] - assert transcription_options.language is None - assert file_transcription_options.file_paths == [ - 'testdata/whisper-french.mp3'] - assert len(model_path) > 0 - - class TestAboutDialog: def test_should_check_for_updates(self, qtbot: QtBot): reply = MockNetworkReply(data={'name': 'v' + VERSION}) diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index 09f54203..4ed8b11d 100644 --- a/tests/transcriber_test.py +++ b/tests/transcriber_test.py @@ -89,14 +89,39 @@ class TestWhisperCppFileTranscriber: class TestWhisperFileTranscriber: + @pytest.mark.parametrize( + 'output_format,expected_file_path,default_output_file_name', + [ + (OutputFormat.SRT, '/a/b/c-translate--Whisper-tiny.srt', '{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}'), + ]) + def test_default_output_file2(self, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str): + file_path = get_default_output_file_path( + task=FileTranscriptionTask( + file_path='/a/b/c.mp4', + transcription_options=TranscriptionOptions(task=Task.TRANSLATE), + file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name=default_output_file_name), + model_path=''), + output_format=output_format) + assert file_path == expected_file_path + def test_default_output_file(self): srt = get_default_output_file_path( - Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.TXT) + task=FileTranscriptionTask( + file_path='/a/b/c.mp4', + transcription_options=TranscriptionOptions(task=Task.TRANSLATE), + file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'), + model_path=''), + output_format=OutputFormat.TXT) assert srt.startswith('/a/b/c (Translated on ') assert srt.endswith('.txt') srt = get_default_output_file_path( - Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.SRT) + task=FileTranscriptionTask( + file_path='/a/b/c.mp4', + transcription_options=TranscriptionOptions(task=Task.TRANSLATE), + file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'), + model_path=''), + output_format=OutputFormat.SRT) assert srt.startswith('/a/b/c (Translated on ') assert srt.endswith('.srt') diff --git a/tests/widgets/file_transcriber_widget_test.py b/tests/widgets/file_transcriber_widget_test.py new file mode 100644 index 00000000..431b4321 --- /dev/null +++ b/tests/widgets/file_transcriber_widget_test.py @@ -0,0 +1,32 @@ +from unittest.mock import Mock + +from PyQt6.QtCore import Qt +from pytestqt.qtbot import QtBot + +from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget + + +class TestFileTranscriberWidget: + def test_should_set_window_title(self, qtbot: QtBot): + widget = FileTranscriberWidget( + file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None) + qtbot.add_widget(widget) + assert widget.windowTitle() == 'whisper-french.mp3' + + def test_should_emit_triggered_event(self, qtbot: QtBot): + widget = FileTranscriberWidget( + file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None) + qtbot.add_widget(widget) + + mock_triggered = Mock() + widget.triggered.connect(mock_triggered) + + with qtbot.wait_signal(widget.triggered, timeout=30 * 1000): + qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton) + + transcription_options, file_transcription_options, model_path = mock_triggered.call_args[ + 0][0] + assert transcription_options.language is None + assert file_transcription_options.file_paths == [ + 'testdata/whisper-french.mp3'] + assert len(model_path) > 0 diff --git a/tests/widgets/preferences_dialog/general_preferences_widget_test.py b/tests/widgets/preferences_dialog/general_preferences_widget_test.py index 44d7fa82..5261fa2f 100644 --- a/tests/widgets/preferences_dialog/general_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/general_preferences_widget_test.py @@ -10,7 +10,8 @@ from buzz.widgets.preferences_dialog.general_preferences_widget import \ class TestGeneralPreferencesWidget: def test_should_disable_test_button_if_no_api_key(self, qtbot): - widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store('')) + widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''), + default_export_file_name='') qtbot.add_widget(widget) test_button = widget.findChild(QPushButton) @@ -26,7 +27,9 @@ class TestGeneralPreferencesWidget: assert test_button.isEnabled() def test_should_test_openai_api_key(self, qtbot): - widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store('wrong-api-key')) + widget = GeneralPreferencesWidget( + keyring_store=self.get_keyring_store('wrong-api-key'), + default_export_file_name='') qtbot.add_widget(widget) test_button = widget.findChild(QPushButton) diff --git a/tests/widgets/preferences_dialog/preferences_dialog_test.py b/tests/widgets/preferences_dialog/preferences_dialog_test.py index 6478e28f..0dacb484 100644 --- a/tests/widgets/preferences_dialog/preferences_dialog_test.py +++ b/tests/widgets/preferences_dialog/preferences_dialog_test.py @@ -6,7 +6,7 @@ from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog class TestPreferencesDialog: def test_create(self, qtbot: QtBot): - dialog = PreferencesDialog(shortcuts={}) + dialog = PreferencesDialog(shortcuts={}, default_export_file_name='') qtbot.add_widget(dialog) assert dialog.windowTitle() == 'Preferences'