Add default file name setting (#559)

This commit is contained in:
Chidi Williams 2023-08-04 18:02:20 -07:00 committed by GitHub
commit f5f77b3908
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 843 additions and 620 deletions

10
buzz/dialogs.py Normal file
View file

@ -0,0 +1,10 @@
from PyQt6.QtWidgets import QWidget, QMessageBox
def show_model_download_error_dialog(parent: QWidget, error: str):
message = parent.tr(
'An error occurred while loading the Whisper model') + \
f": {error}{'' if error.endswith('.') else '.'}" + \
parent.tr("Please retry or check the application logs for more information.")
QMessageBox.critical(parent, '', message)

View file

@ -1,51 +1,44 @@
import enum
import json
import logging
import sys
from enum import auto
from typing import Dict, List, Optional, Tuple
import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QPoint,
QUrlQuery, QMetaObject, QEvent, QThreadPool)
from PyQt6.QtCore import (Qt, QThread,
pyqtSignal, QModelIndex, QThreadPool)
from PyQt6.QtGui import (QCloseEvent, QIcon,
QKeySequence, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox,
QFormLayout,
QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
QKeySequence, QTextCursor, QPainter, QColor)
from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QFormLayout,
QSizePolicy)
from buzz.cache import TasksCache
from .__version__ import VERSION
from .action import Action
from .assets import get_asset_path
from .dialogs import show_model_download_error_dialog
from .widgets.icon import Icon, BUZZ_ICON_PATH
from .locale import _
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \
ModelDownloader
from .paths import file_paths_as_title
from .recording import RecordingAmplitudeListener
from .settings.settings import Settings, APP_NAME
from .settings.shortcut import Shortcut
from .settings.shortcut_settings import ShortcutSettings
from .store.keyring_store import KeyringStore
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
Task,
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Task,
TranscriptionOptions,
FileTranscriptionTask, LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE, LANGUAGES)
DEFAULT_WHISPER_TEMPERATURE)
from .recording_transcriber import RecordingTranscriber
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
from .widgets.line_edit import LineEdit
from .widgets.menu_bar import MenuBar
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from .widgets.model_type_combo_box import ModelTypeComboBox
from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
from .widgets.toolbar import ToolBar
from .widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
from .widgets.transcriber.transcription_options_group_box import \
TranscriptionOptionsGroupBox
from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
@ -102,45 +95,6 @@ class AudioDevicesComboBox(QComboBox):
return -1
class LanguagesComboBox(QComboBox):
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
languageChanged = pyqtSignal(str)
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
super().__init__(parent)
whisper_languages = sorted(
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1])
self.languages = [('', _('Detect Language'))] + whisper_languages
self.addItems([lang[1] for lang in self.languages])
self.currentIndexChanged.connect(self.on_index_changed)
default_language_key = default_language if default_language != '' else None
for i, lang in enumerate(self.languages):
if lang[0] == default_language_key:
self.setCurrentIndex(i)
def on_index_changed(self, index: int):
self.languageChanged.emit(self.languages[index][0])
class TasksComboBox(QComboBox):
"""TasksComboBox displays a list of tasks available to use with Whisper"""
taskChanged = pyqtSignal(Task)
def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None:
super().__init__(parent, *args)
self.tasks = [i for i in Task]
self.addItems(map(lambda task: task.value.title(), self.tasks))
self.currentIndexChanged.connect(self.on_index_changed)
self.setCurrentText(default_task.value.title())
def on_index_changed(self, index: int):
self.taskChanged.emit(self.tasks[index])
class TextDisplayBox(QPlainTextEdit):
"""TextDisplayBox is a read-only textbox"""
@ -164,181 +118,6 @@ class RecordButton(QPushButton):
self.setDefault(False)
def show_model_download_error_dialog(parent: QWidget, error: str):
message = parent.tr(
'An error occurred while loading the Whisper model') + \
f": {error}{'' if error.endswith('.') else '.'}" + \
parent.tr("Please retry or check the application logs for more information.")
QMessageBox.critical(parent, '', message)
class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
model_loader: Optional[ModelDownloader] = None
file_transcription_options: FileTranscriptionOptions
transcription_options: TranscriptionOptions
is_transcribing = False
# (TranscriptionOptions, FileTranscriptionOptions, str)
triggered = pyqtSignal(tuple)
openai_access_token_changed = pyqtSignal(str)
settings = Settings()
def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
super().__init__(parent, flags)
self.setWindowTitle(file_paths_as_title(file_paths))
openai_access_token = KeyringStore().get_password(KeyringStore.Key.OPENAI_API_KEY)
self.file_paths = file_paths
default_language = self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='')
self.transcription_options = TranscriptionOptions(
openai_access_token=openai_access_token,
model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL, default_value=TranscriptionModel()),
task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE),
language=default_language if default_language != '' else None,
initial_prompt=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
temperature=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
default_value=DEFAULT_WHISPER_TEMPERATURE),
word_level_timings=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
default_value=False))
default_export_format_states: List[str] = self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
default_value=[])
self.file_transcription_options = FileTranscriptionOptions(
file_paths=self.file_paths,
output_formats=set([OutputFormat(output_format) for output_format in default_export_format_states]))
layout = QVBoxLayout(self)
transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options, parent=self)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)
self.word_level_timings_checkbox = QCheckBox(_('Word-level timings'))
self.word_level_timings_checkbox.setChecked(
self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, default_value=False))
self.word_level_timings_checkbox.stateChanged.connect(
self.on_word_level_timings_changed)
file_transcription_layout = QFormLayout()
file_transcription_layout.addRow('', self.word_level_timings_checkbox)
export_format_layout = QHBoxLayout()
for output_format in OutputFormat:
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}', parent=self)
export_format_checkbox.setChecked(output_format in self.file_transcription_options.output_formats)
export_format_checkbox.stateChanged.connect(self.get_on_checkbox_state_changed_callback(output_format))
export_format_layout.addWidget(export_format_checkbox)
file_transcription_layout.addRow('Export:', export_format_layout)
self.run_button = QPushButton(_('Run'), self)
self.run_button.setDefault(True)
self.run_button.clicked.connect(self.on_click_run)
layout.addWidget(transcription_options_group_box)
layout.addLayout(file_transcription_layout)
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
def on_checkbox_state_changed(state: int):
if state == Qt.CheckState.Checked.value:
self.file_transcription_options.output_formats.add(output_format)
elif state == Qt.CheckState.Unchecked.value:
self.file_transcription_options.output_formats.remove(output_format)
return on_checkbox_state_changed
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
if self.transcription_options.openai_access_token != '':
self.openai_access_token_changed.emit(self.transcription_options.openai_access_token)
def on_click_run(self):
self.run_button.setDisabled(True)
model_path = self.transcription_options.model.get_local_model_path()
if model_path is not None:
self.on_model_loaded(model_path)
return
self.model_loader = ModelDownloader(model=self.transcription_options.model)
self.model_loader.signals.progress.connect(self.on_download_model_progress)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.model_loader.signals.finished.connect(self.on_model_loaded)
QThreadPool().globalInstance().start(self.model_loader)
def on_model_loaded(self, model_path: str):
self.reset_transcriber_controls()
self.triggered.emit((self.transcription_options,
self.file_transcription_options, model_path))
self.close()
def on_download_model_progress(self, progress: Tuple[float, float]):
(current_size, total_size) = progress
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
def on_download_model_error(self, error: str):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.reset_transcriber_controls()
def reset_transcriber_controls(self):
self.run_button.setDisabled(False)
def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.cancel()
self.reset_model_download()
def reset_model_download(self):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
def on_word_level_timings_changed(self, value: int):
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
if self.model_loader is not None:
self.model_loader.cancel()
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, self.transcription_options.temperature)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, self.transcription_options.initial_prompt)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model)
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
value=self.transcription_options.word_level_timings)
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
value=[export_format.value for export_format in
self.file_transcription_options.output_formats])
super().closeEvent(event)
class AdvancedSettingsButton(QPushButton):
def __init__(self, parent: Optional[QWidget]) -> None:
super().__init__('Advanced...', parent)
class AudioMeterWidget(QWidget):
current_amplitude: float
BAR_WIDTH = 2
@ -730,6 +509,9 @@ class MainWindow(QMainWindow):
self.shortcut_settings = ShortcutSettings(settings=self.settings)
self.shortcuts = self.shortcut_settings.load()
self.default_export_file_name = self.settings.value(
Settings.Key.DEFAULT_EXPORT_FILE_NAME,
'{{ input_file_name }} ({{ task }}d on {{ date_time }})')
self.tasks = {}
self.tasks_changed.connect(self.on_tasks_changed)
@ -742,11 +524,14 @@ class MainWindow(QMainWindow):
self.addToolBar(self.toolbar)
self.setUnifiedTitleAndToolBarOnMac(True)
self.menu_bar = MenuBar(shortcuts=self.shortcuts, parent=self)
self.menu_bar = MenuBar(shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self)
self.menu_bar.import_action_triggered.connect(
self.on_new_transcription_action_triggered)
self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed)
self.menu_bar.openai_api_key_changed.connect(self.on_openai_access_token_changed)
self.menu_bar.default_export_file_name_changed.connect(self.default_export_file_name_changed)
self.setMenuBar(self.menu_bar)
self.table_widget = TranscriptionTasksTableWidget(self)
@ -840,6 +625,7 @@ class MainWindow(QMainWindow):
def open_file_transcriber_widget(self, file_paths: List[str]):
file_transcriber_window = FileTranscriberWidget(file_paths=file_paths,
default_output_file_name=self.default_export_file_name,
parent=self,
flags=Qt.WindowType.Window)
file_transcriber_window.triggered.connect(
@ -851,6 +637,10 @@ class MainWindow(QMainWindow):
def on_openai_access_token_changed(access_token: str):
KeyringStore().set_password(KeyringStore.Key.OPENAI_API_KEY, access_token)
def default_export_file_name_changed(self, default_export_file_name: str):
self.default_export_file_name = default_export_file_name
self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name)
def open_transcript_viewer(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
for selected_row in selected_rows:
@ -933,241 +723,6 @@ class MainWindow(QMainWindow):
super().closeEvent(event)
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
class HuggingFaceSearchLineEdit(LineEdit):
model_selected = pyqtSignal(str)
popup: QListWidget
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None):
super().__init__('', parent)
self.setMinimumWidth(150)
self.setPlaceholderText('openai/whisper-tiny')
self.timer = QTimer(self)
self.timer.setSingleShot(True)
self.timer.setInterval(250)
self.timer.timeout.connect(self.fetch_models)
# Restart debounce timer each time editor text changes
self.textEdited.connect(self.timer.start)
self.textEdited.connect(self.on_text_edited)
if network_access_manager is None:
network_access_manager = QNetworkAccessManager(self)
self.network_manager = network_access_manager
self.network_manager.finished.connect(self.on_request_response)
self.popup = QListWidget()
self.popup.setWindowFlags(Qt.WindowType.Popup)
self.popup.setFocusProxy(self)
self.popup.setMouseTracking(True)
self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.popup.installEventFilter(self)
self.popup.itemClicked.connect(self.on_select_item)
def on_text_edited(self, text: str):
self.model_selected.emit(text)
def on_select_item(self):
self.popup.hide()
self.setFocus()
item = self.popup.currentItem()
self.setText(item.text())
QMetaObject.invokeMethod(self, 'returnPressed')
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
def fetch_models(self):
text = self.text()
if len(text) < 3:
return
url = QUrl("https://huggingface.co/api/models")
query = QUrlQuery()
query.addQueryItem("filter", "whisper")
query.addQueryItem("search", text)
url.setQuery(query)
return self.network_manager.get(QNetworkRequest(url))
def on_popup_selected(self):
self.timer.stop()
def on_request_response(self, network_reply: QNetworkReply):
if network_reply.error() != QNetworkReply.NetworkError.NoError:
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
return
models = json.loads(network_reply.readAll().data())
self.popup.setUpdatesEnabled(False)
self.popup.clear()
for model in models:
model_id = model.get('id')
item = QListWidgetItem(self.popup)
item.setText(model_id)
item.setData(Qt.ItemDataRole.UserRole, model_id)
self.popup.setCurrentItem(self.popup.item(0))
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
self.popup.setUpdatesEnabled(True)
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
self.popup.setFocus()
self.popup.show()
def eventFilter(self, target: QObject, event: QEvent):
if hasattr(self, 'popup') is False or target != self.popup:
return False
if event.type() == QEvent.Type.MouseButtonPress:
self.popup.hide()
self.setFocus()
return True
if isinstance(event, QKeyEvent):
key = event.key()
if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]:
if self.popup.currentItem() is not None:
self.on_select_item()
return True
if key == Qt.Key.Key_Escape:
self.setFocus()
self.popup.hide()
return True
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
Qt.Key.Key_PageDown]:
return False
self.setFocus()
self.event(event)
self.popup.hide()
return False
class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
model_types: Optional[List[ModelType]] = None,
parent: Optional[QWidget] = None):
super().__init__(title='', parent=parent)
self.transcription_options = default_transcription_options
self.form_layout = QFormLayout(self)
self.tasks_combo_box = TasksComboBox(
default_task=self.transcription_options.task,
parent=self)
self.tasks_combo_box.taskChanged.connect(self.on_task_changed)
self.languages_combo_box = LanguagesComboBox(
default_language=self.transcription_options.language,
parent=self)
self.languages_combo_box.languageChanged.connect(
self.on_language_changed)
self.advanced_settings_button = AdvancedSettingsButton(self)
self.advanced_settings_button.clicked.connect(
self.open_advanced_settings)
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
default_model=default_transcription_options.model.model_type,
parent=self)
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
if default_transcription_options.model.whisper_model_size is not None:
self.whisper_model_size_combo_box.setCurrentText(
default_transcription_options.model.whisper_model_size.value.title())
self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed)
self.openai_access_token_edit = OpenAIAPIKeyLineEdit(key=default_transcription_options.openai_access_token,
parent=self)
self.openai_access_token_edit.key_changed.connect(self.on_openai_access_token_edit_changed)
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
self.form_layout.addRow('', self.whisper_model_size_combo_box)
self.form_layout.addRow('', self.hugging_face_search_line_edit)
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
self.reset_visible_rows()
self.form_layout.addRow('', self.advanced_settings_button)
self.setLayout(self.form_layout)
def on_openai_access_token_edit_changed(self, access_token: str):
self.transcription_options.openai_access_token = access_token
self.transcription_options_changed.emit(self.transcription_options)
def on_language_changed(self, language: str):
self.transcription_options.language = language
self.transcription_options_changed.emit(self.transcription_options)
def on_task_changed(self, task: Task):
self.transcription_options.task = task
self.transcription_options_changed.emit(self.transcription_options)
def on_temperature_changed(self, temperature: Tuple[float, ...]):
self.transcription_options.temperature = temperature
self.transcription_options_changed.emit(self.transcription_options)
def on_initial_prompt_changed(self, initial_prompt: str):
self.transcription_options.initial_prompt = initial_prompt
self.transcription_options_changed.emit(self.transcription_options)
def open_advanced_settings(self):
dialog = AdvancedSettingsDialog(
transcription_options=self.transcription_options, parent=self)
dialog.transcription_options_changed.connect(
self.on_transcription_options_changed)
dialog.exec()
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
self.transcription_options_changed.emit(transcription_options)
def reset_visible_rows(self):
model_type = self.transcription_options.model.model_type
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP) or (
model_type == ModelType.FASTER_WHISPER))
self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API)
def on_model_type_changed(self, model_type: ModelType):
self.transcription_options.model.model_type = model_type
self.reset_visible_rows()
self.transcription_options_changed.emit(self.transcription_options)
def on_whisper_model_size_changed(self, text: str):
model_size = WhisperModelSize(text.lower())
self.transcription_options.model.whisper_model_size = model_size
self.transcription_options_changed.emit(self.transcription_options)
def on_hugging_face_model_changed(self, model: str):
self.transcription_options.model.hugging_face_model_id = model
self.transcription_options_changed.emit(self.transcription_options)
class Application(QApplication):
window: MainWindow
@ -1183,73 +738,3 @@ class Application(QApplication):
def add_task(self, task: FileTranscriptionTask):
self.window.add_task(task)
class AdvancedSettingsDialog(QDialog):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None):
super().__init__(parent)
self.transcription_options = transcription_options
self.setWindowTitle(_('Advanced Settings'))
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
QDialogButtonBox.StandardButton.Ok), self)
button_box.accepted.connect(self.accept)
layout = QFormLayout(self)
default_temperature_text = ', '.join(
[str(temp) for temp in transcription_options.temperature])
self.temperature_line_edit = LineEdit(default_temperature_text, self)
self.temperature_line_edit.setPlaceholderText(
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"'))
self.temperature_line_edit.setMinimumWidth(170)
self.temperature_line_edit.textChanged.connect(
self.on_temperature_changed)
self.temperature_line_edit.setValidator(TemperatureValidator(self))
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
self.initial_prompt_text_edit = QPlainTextEdit(
transcription_options.initial_prompt, self)
self.initial_prompt_text_edit.textChanged.connect(
self.on_initial_prompt_changed)
self.initial_prompt_text_edit.setEnabled(
transcription_options.model.model_type == ModelType.WHISPER)
layout.addRow(_('Temperature:'), self.temperature_line_edit)
layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit)
layout.addWidget(button_box)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def on_temperature_changed(self, text: str):
try:
temperatures = [float(temp.strip()) for temp in text.split(',')]
self.transcription_options.temperature = tuple(temperatures)
self.transcription_options_changed.emit(self.transcription_options)
except ValueError:
pass
def on_initial_prompt_changed(self):
self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText()
self.transcription_options_changed.emit(self.transcription_options)
class TemperatureValidator(QValidator):
def __init__(self, parent: Optional[QObject] = ...) -> None:
super().__init__(parent)
def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]:
try:
temp_strings = [temp.strip() for temp in text.split(',')]
if temp_strings[-1] == '':
return QValidator.State.Intermediate, text, cursor_position
_ = [float(temp) for temp in temp_strings]
return QValidator.State.Acceptable, text, cursor_position
except ValueError:
return QValidator.State.Invalid, text, cursor_position

View file

@ -19,7 +19,7 @@ from platformdirs import user_cache_dir
from tqdm.auto import tqdm
class WhisperModelSize(enum.Enum):
class WhisperModelSize(str, enum.Enum):
TINY = 'tiny'
BASE = 'base'
SMALL = 'small'

View file

@ -25,14 +25,18 @@ class Settings:
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings'
FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats'
DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name'
SHORTCUTS = 'shortcuts'
def set_value(self, key: Key, value: typing.Any) -> None:
self.settings.setValue(key.value, value)
def value(self, key: Key, default_value: typing.Any, value_type: typing.Optional[type] = None) -> typing.Any:
def value(self, key: Key, default_value: typing.Any,
value_type: typing.Optional[type] = None) -> typing.Any:
return self.settings.value(key.value, default_value,
value_type if value_type is not None else type(default_value))
value_type if value_type is not None else type(
default_value))
def clear(self):
self.settings.clear()

View file

@ -66,13 +66,15 @@ class TranscriptionOptions:
word_level_timings: bool = False
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
initial_prompt: str = ''
openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS))
openai_access_token: str = field(default='',
metadata=config(exclude=Exclude.ALWAYS))
@dataclass()
class FileTranscriptionOptions:
file_paths: List[str]
output_formats: Set['OutputFormat'] = field(default_factory=set)
default_output_file_name: str = ''
@dataclass_json
@ -127,12 +129,11 @@ class FileTranscriber(QObject):
self.completed.emit(segments)
for output_format in self.transcription_task.file_transcription_options.output_formats:
default_path = get_default_output_file_path(
task=self.transcription_task.transcription_options.task,
input_file_path=self.transcription_task.file_path,
output_format=output_format)
default_path = get_default_output_file_path(task=self.transcription_task,
output_format=output_format)
write_output(path=default_path, segments=segments, output_format=output_format)
write_output(path=default_path, segments=segments,
output_format=output_format)
@abstractmethod
def transcribe(self) -> List[Segment]:
@ -172,17 +173,22 @@ class WhisperCppFileTranscriber(FileTranscriber):
logging.debug(
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, '
'word level timings = %s',
self.file_path, self.language, self.task, model_path, self.word_level_timings)
self.file_path, self.language, self.task, model_path,
self.word_level_timings)
audio = whisper.audio.load_audio(self.file_path)
self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE
whisper_params = whisper_cpp_params(language=self.language if self.language is not None else '', task=self.task,
word_level_timings=self.word_level_timings)
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(id(self.state))
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback)
whisper_params = whisper_cpp_params(
language=self.language if self.language is not None else '', task=self.task,
word_level_timings=self.word_level_timings)
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(
id(self.state))
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(
self.encoder_begin_callback)
whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state))
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(self.new_segment_callback)
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(
self.new_segment_callback)
model = WhisperCpp(model=model_path)
result = model.transcribe(audio=self.file_path, params=whisper_params)
@ -199,13 +205,15 @@ class WhisperCppFileTranscriber(FileTranscriber):
# t1 seems to sometimes be larger than the duration when the
# audio ends in silence. Trim to fix the displayed progress.
progress = min(t1 * 10, self.duration_audio_ms)
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
ctypes.py_object).value
if state.running:
self.progress.emit((progress, self.duration_audio_ms))
@staticmethod
def encoder_begin_callback(_ctx, _state, user_data):
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
ctypes.py_object).value
return state.running == 1
def stop(self):
@ -219,8 +227,10 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
self.task = task.transcription_options.task
def transcribe(self) -> List[Segment]:
logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
self.task)
logging.debug(
'Starting OpenAI Whisper API file transcription, file path = %s, task = %s',
self.file_path,
self.task)
wav_file = tempfile.mktemp() + '.wav'
(
@ -235,14 +245,18 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
language = self.transcription_task.transcription_options.language
response_format = "verbose_json"
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
transcript = openai.Audio.translate("whisper-1", audio_file,
response_format=response_format,
language=language)
else:
transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
transcript = openai.Audio.transcribe("whisper-1", audio_file,
response_format=response_format,
language=language)
segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in
transcript["segments"]]
segments = [
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for
segment in
transcript["segments"]]
return segments
def stop(self):
@ -273,7 +287,8 @@ class WhisperFileTranscriber(FileTranscriber):
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
self.current_process = multiprocessing.Process(target=self.transcribe_whisper,
args=(send_pipe, self.transcription_task))
args=(send_pipe,
self.transcription_task))
if not self.stopped:
self.current_process.start()
self.started_process = True
@ -291,7 +306,8 @@ class WhisperFileTranscriber(FileTranscriber):
logging.debug(
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments))
self.current_process.exitcode, datetime.datetime.now() - time_started,
len(self.segments))
if self.current_process.exitcode != 0:
raise Exception('Unknown error')
@ -299,7 +315,8 @@ class WhisperFileTranscriber(FileTranscriber):
return self.segments
@classmethod
def transcribe_whisper(cls, stderr_conn: Connection, task: FileTranscriptionTask) -> None:
def transcribe_whisper(cls, stderr_conn: Connection,
task: FileTranscriptionTask) -> None:
with pipe_stderr(stderr_conn):
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
segments = cls.transcribe_hugging_face(task)
@ -308,7 +325,8 @@ class WhisperFileTranscriber(FileTranscriber):
elif task.transcription_options.model.model_type == ModelType.WHISPER:
segments = cls.transcribe_openai_whisper(task)
else:
raise Exception(f"Invalid model type: {task.transcription_options.model.model_type}")
raise Exception(
f"Invalid model type: {task.transcription_options.model.model_type}")
segments_json = json.dumps(
segments, ensure_ascii=True, default=vars)
@ -321,7 +339,8 @@ class WhisperFileTranscriber(FileTranscriber):
model = transformers_whisper.load_model(task.model_path)
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
result = model.transcribe(audio=task.file_path, language=language,
task=task.transcription_options.task.value, verbose=False)
task=task.transcription_options.task.value,
verbose=False)
return [
Segment(
start=int(segment.get('start') * 1000),
@ -368,7 +387,8 @@ class WhisperFileTranscriber(FileTranscriber):
stable_whisper.modify_model(model)
result = model.transcribe(
audio=task.file_path, language=task.transcription_options.language,
task=task.transcription_options.task.value, temperature=task.transcription_options.temperature,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
segments = stable_whisper.group_word_timestamps(result)
return [Segment(
@ -423,7 +443,8 @@ class WhisperFileTranscriber(FileTranscriber):
def write_output(path: str, segments: List[Segment], output_format: OutputFormat):
logging.debug(
'Writing transcription output, path = %s, output format = %s, number of segments = %s', path, output_format,
'Writing transcription output, path = %s, output format = %s, number of segments = %s',
path, output_format,
len(segments))
with open(path, 'w', encoding='utf-8') as file:
@ -473,8 +494,22 @@ SUPPORTED_OUTPUT_FORMATS = 'Audio files (*.mp3 *.wav *.m4a *.ogg);;\
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)'
def get_default_output_file_path(task: Task, input_file_path: str, output_format: OutputFormat):
return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).{output_format.value}'
def get_default_output_file_path(task: FileTranscriptionTask,
output_format: OutputFormat):
input_file_name = os.path.splitext(task.file_path)[0]
date_time_now = datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S')
return (task.file_transcription_options.default_output_file_name
.replace('{{ input_file_name }}', input_file_name)
.replace('{{ task }}', task.transcription_options.task.value)
.replace('{{ language }}', task.transcription_options.language or '')
.replace('{{ model_type }}',
task.transcription_options.model.model_type.value)
.replace('{{ model_size }}',
task.transcription_options.model.whisper_model_size.value if
task.transcription_options.model.whisper_model_size is not None else
'')
.replace('{{ date_time }}', date_time_now)
+ f".{output_format.value}")
def whisper_cpp_params(

View file

@ -1,3 +1,4 @@
import webbrowser
from typing import Dict
from PyQt6.QtCore import pyqtSignal
@ -15,11 +16,14 @@ class MenuBar(QMenuBar):
import_action_triggered = pyqtSignal()
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(self, shortcuts: Dict[str, str], parent: QWidget):
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
parent: QWidget):
super().__init__(parent)
self.shortcuts = shortcuts
self.default_export_file_name = default_export_file_name
self.import_action = QAction(_("Import Media File..."), self)
self.import_action.triggered.connect(
@ -31,6 +35,9 @@ class MenuBar(QMenuBar):
self.preferences_action = QAction(_("Preferences..."), self)
self.preferences_action.triggered.connect(self.on_preferences_action_triggered)
help_action = QAction(f'{_("Help")}', self)
help_action.triggered.connect(self.on_help_action_triggered)
self.set_shortcuts(shortcuts)
file_menu = self.addMenu(_("File"))
@ -38,6 +45,7 @@ class MenuBar(QMenuBar):
help_menu = self.addMenu(_("Help"))
help_menu.addAction(about_action)
help_menu.addAction(help_action)
help_menu.addAction(self.preferences_action)
def on_import_action_triggered(self):
@ -48,11 +56,18 @@ class MenuBar(QMenuBar):
about_dialog.open()
def on_preferences_action_triggered(self):
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, parent=self)
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self)
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
preferences_dialog.default_export_file_name_changed.connect(
self.default_export_file_name_changed)
preferences_dialog.open()
def on_help_action_triggered(self):
webbrowser.open('https://chidiwilliams.github.io/buzz/docs')
def set_shortcuts(self, shortcuts: Dict[str, str]):
self.shortcuts = shortcuts

View file

@ -3,33 +3,45 @@ from typing import Optional
import openai
from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox, QLineEdit
from openai.error import AuthenticationError
from buzz.store.keyring_store import KeyringStore
from buzz.widgets.line_edit import LineEdit
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
class GeneralPreferencesWidget(QWidget):
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(self, keyring_store=KeyringStore(), parent: Optional[QWidget] = None):
def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(),
parent: Optional[QWidget] = None):
super().__init__(parent)
self.openai_api_key = keyring_store.get_password(KeyringStore.Key.OPENAI_API_KEY)
self.openai_api_key = keyring_store.get_password(
KeyringStore.Key.OPENAI_API_KEY)
layout = QFormLayout(self)
self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self)
self.openai_api_key_line_edit.key_changed.connect(self.on_openai_api_key_changed)
self.openai_api_key_line_edit.key_changed.connect(
self.on_openai_api_key_changed)
self.test_openai_api_key_button = QPushButton('Test')
self.test_openai_api_key_button.clicked.connect(self.on_click_test_openai_api_key_button)
self.test_openai_api_key_button.clicked.connect(
self.on_click_test_openai_api_key_button)
self.update_test_openai_api_key_button()
layout.addRow('OpenAI API Key', self.openai_api_key_line_edit)
layout.addRow('', self.test_openai_api_key_button)
default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
default_export_file_name_line_edit.textChanged.connect(
self.default_export_file_name_changed)
default_export_file_name_line_edit.setMinimumWidth(200)
layout.addRow('Default export file name', default_export_file_name_line_edit)
self.setLayout(layout)
def update_test_openai_api_key_button(self):

View file

@ -15,8 +15,9 @@ from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import
class PreferencesDialog(QDialog):
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(self, shortcuts: Dict[str, str],
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
parent: Optional[QWidget] = None) -> None:
super().__init__(parent)
@ -25,8 +26,11 @@ class PreferencesDialog(QDialog):
layout = QVBoxLayout(self)
tab_widget = QTabWidget(self)
general_tab_widget = GeneralPreferencesWidget(parent=self)
general_tab_widget = GeneralPreferencesWidget(
default_export_file_name=default_export_file_name, parent=self)
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
general_tab_widget.default_export_file_name_changed.connect(
self.default_export_file_name_changed)
tab_widget.addTab(general_tab_widget, _('General'))
models_tab_widget = ModelsPreferencesWidget(parent=self)

View file

View file

@ -0,0 +1,8 @@
from typing import Optional
from PyQt6.QtWidgets import QPushButton, QWidget
class AdvancedSettingsButton(QPushButton):
def __init__(self, parent: Optional[QWidget]) -> None:
super().__init__('Advanced...', parent)

View file

@ -0,0 +1,64 @@
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QDialog, QWidget, QDialogButtonBox, QFormLayout, \
QPlainTextEdit
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
from buzz.locale import _
from buzz.model_loader import ModelType
from buzz.transcriber import TranscriptionOptions
from buzz.widgets.line_edit import LineEdit
class AdvancedSettingsDialog(QDialog):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None):
super().__init__(parent)
self.transcription_options = transcription_options
self.setWindowTitle(_('Advanced Settings'))
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
QDialogButtonBox.StandardButton.Ok), self)
button_box.accepted.connect(self.accept)
layout = QFormLayout(self)
default_temperature_text = ', '.join(
[str(temp) for temp in transcription_options.temperature])
self.temperature_line_edit = LineEdit(default_temperature_text, self)
self.temperature_line_edit.setPlaceholderText(
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"'))
self.temperature_line_edit.setMinimumWidth(170)
self.temperature_line_edit.textChanged.connect(
self.on_temperature_changed)
self.temperature_line_edit.setValidator(TemperatureValidator(self))
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
self.initial_prompt_text_edit = QPlainTextEdit(
transcription_options.initial_prompt, self)
self.initial_prompt_text_edit.textChanged.connect(
self.on_initial_prompt_changed)
self.initial_prompt_text_edit.setEnabled(
transcription_options.model.model_type == ModelType.WHISPER)
layout.addRow(_('Temperature:'), self.temperature_line_edit)
layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit)
layout.addWidget(button_box)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def on_temperature_changed(self, text: str):
try:
temperatures = [float(temp.strip()) for temp in text.split(',')]
self.transcription_options.temperature = tuple(temperatures)
self.transcription_options_changed.emit(self.transcription_options)
except ValueError:
pass
def on_initial_prompt_changed(self):
self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText()
self.transcription_options_changed.emit(self.transcription_options)

View file

@ -0,0 +1,204 @@
from typing import Optional, List, Tuple
from PyQt6 import QtGui
from PyQt6.QtCore import pyqtSignal, Qt, QThreadPool
from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout, \
QPushButton
from buzz.dialogs import show_model_download_error_dialog
from buzz.locale import _
from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType
from buzz.paths import file_paths_as_title
from buzz.settings.settings import Settings
from buzz.store.keyring_store import KeyringStore
from buzz.transcriber import FileTranscriptionOptions, TranscriptionOptions, Task, \
DEFAULT_WHISPER_TEMPERATURE, OutputFormat
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from buzz.widgets.transcriber.transcription_options_group_box import \
TranscriptionOptionsGroupBox
class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
model_loader: Optional[ModelDownloader] = None
file_transcription_options: FileTranscriptionOptions
transcription_options: TranscriptionOptions
is_transcribing = False
# (TranscriptionOptions, FileTranscriptionOptions, str)
triggered = pyqtSignal(tuple)
openai_access_token_changed = pyqtSignal(str)
settings = Settings()
def __init__(self, file_paths: List[str],
default_output_file_name: str,
parent: Optional[QWidget] = None,
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
super().__init__(parent, flags)
self.setWindowTitle(file_paths_as_title(file_paths))
openai_access_token = KeyringStore().get_password(
KeyringStore.Key.OPENAI_API_KEY)
self.file_paths = file_paths
default_language = self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='')
self.transcription_options = TranscriptionOptions(
openai_access_token=openai_access_token,
model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL,
default_value=TranscriptionModel()),
task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK,
default_value=Task.TRANSCRIBE),
language=default_language if default_language != '' else None,
initial_prompt=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
temperature=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
default_value=DEFAULT_WHISPER_TEMPERATURE),
word_level_timings=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
default_value=False))
default_export_format_states: List[str] = self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
default_value=[])
self.file_transcription_options = FileTranscriptionOptions(
file_paths=self.file_paths,
output_formats=set([OutputFormat(output_format) for output_format in
default_export_format_states]),
default_output_file_name=default_output_file_name)
layout = QVBoxLayout(self)
transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options, parent=self)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)
self.word_level_timings_checkbox = QCheckBox(_('Word-level timings'))
self.word_level_timings_checkbox.setChecked(
self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
default_value=False))
self.word_level_timings_checkbox.stateChanged.connect(
self.on_word_level_timings_changed)
file_transcription_layout = QFormLayout()
file_transcription_layout.addRow('', self.word_level_timings_checkbox)
export_format_layout = QHBoxLayout()
for output_format in OutputFormat:
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}',
parent=self)
export_format_checkbox.setChecked(
output_format in self.file_transcription_options.output_formats)
export_format_checkbox.stateChanged.connect(
self.get_on_checkbox_state_changed_callback(output_format))
export_format_layout.addWidget(export_format_checkbox)
file_transcription_layout.addRow('Export:', export_format_layout)
self.run_button = QPushButton(_('Run'), self)
self.run_button.setDefault(True)
self.run_button.clicked.connect(self.on_click_run)
layout.addWidget(transcription_options_group_box)
layout.addLayout(file_transcription_layout)
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
def on_checkbox_state_changed(state: int):
if state == Qt.CheckState.Checked.value:
self.file_transcription_options.output_formats.add(output_format)
elif state == Qt.CheckState.Unchecked.value:
self.file_transcription_options.output_formats.remove(output_format)
return on_checkbox_state_changed
def on_transcription_options_changed(self,
transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or
self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
if self.transcription_options.openai_access_token != '':
self.openai_access_token_changed.emit(
self.transcription_options.openai_access_token)
def on_click_run(self):
self.run_button.setDisabled(True)
model_path = self.transcription_options.model.get_local_model_path()
if model_path is not None:
self.on_model_loaded(model_path)
return
self.model_loader = ModelDownloader(model=self.transcription_options.model)
self.model_loader.signals.progress.connect(self.on_download_model_progress)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.model_loader.signals.finished.connect(self.on_model_loaded)
QThreadPool().globalInstance().start(self.model_loader)
def on_model_loaded(self, model_path: str):
self.reset_transcriber_controls()
self.triggered.emit((self.transcription_options,
self.file_transcription_options, model_path))
self.close()
def on_download_model_progress(self, progress: Tuple[float, float]):
(current_size, total_size) = progress
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_value(
fraction_completed=current_size / total_size)
def on_download_model_error(self, error: str):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.reset_transcriber_controls()
def reset_transcriber_controls(self):
self.run_button.setDisabled(False)
def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.cancel()
self.reset_model_download()
def reset_model_download(self):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
def on_word_level_timings_changed(self, value: int):
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
if self.model_loader is not None:
self.model_loader.cancel()
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE,
self.transcription_options.language)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK,
self.transcription_options.task)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
self.transcription_options.temperature)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT,
self.transcription_options.initial_prompt)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL,
self.transcription_options.model)
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
value=self.transcription_options.word_level_timings)
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
value=[export_format.value for export_format in
self.file_transcription_options.output_formats])
super().closeEvent(event)

View file

@ -0,0 +1,134 @@
import json
import logging
from typing import Optional
from PyQt6.QtCore import pyqtSignal, QTimer, Qt, QMetaObject, QUrl, QUrlQuery, QPoint, \
QObject, QEvent
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem
from buzz.widgets.line_edit import LineEdit
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
class HuggingFaceSearchLineEdit(LineEdit):
model_selected = pyqtSignal(str)
popup: QListWidget
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None):
super().__init__('', parent)
self.setMinimumWidth(150)
self.setPlaceholderText('openai/whisper-tiny')
self.timer = QTimer(self)
self.timer.setSingleShot(True)
self.timer.setInterval(250)
self.timer.timeout.connect(self.fetch_models)
# Restart debounce timer each time editor text changes
self.textEdited.connect(self.timer.start)
self.textEdited.connect(self.on_text_edited)
if network_access_manager is None:
network_access_manager = QNetworkAccessManager(self)
self.network_manager = network_access_manager
self.network_manager.finished.connect(self.on_request_response)
self.popup = QListWidget()
self.popup.setWindowFlags(Qt.WindowType.Popup)
self.popup.setFocusProxy(self)
self.popup.setMouseTracking(True)
self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.popup.installEventFilter(self)
self.popup.itemClicked.connect(self.on_select_item)
def on_text_edited(self, text: str):
self.model_selected.emit(text)
def on_select_item(self):
self.popup.hide()
self.setFocus()
item = self.popup.currentItem()
self.setText(item.text())
QMetaObject.invokeMethod(self, 'returnPressed')
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
def fetch_models(self):
text = self.text()
if len(text) < 3:
return
url = QUrl("https://huggingface.co/api/models")
query = QUrlQuery()
query.addQueryItem("filter", "whisper")
query.addQueryItem("search", text)
url.setQuery(query)
return self.network_manager.get(QNetworkRequest(url))
def on_popup_selected(self):
self.timer.stop()
def on_request_response(self, network_reply: QNetworkReply):
if network_reply.error() != QNetworkReply.NetworkError.NoError:
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
return
models = json.loads(network_reply.readAll().data())
self.popup.setUpdatesEnabled(False)
self.popup.clear()
for model in models:
model_id = model.get('id')
item = QListWidgetItem(self.popup)
item.setText(model_id)
item.setData(Qt.ItemDataRole.UserRole, model_id)
self.popup.setCurrentItem(self.popup.item(0))
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
self.popup.setUpdatesEnabled(True)
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
self.popup.setFocus()
self.popup.show()
def eventFilter(self, target: QObject, event: QEvent):
if hasattr(self, 'popup') is False or target != self.popup:
return False
if event.type() == QEvent.Type.MouseButtonPress:
self.popup.hide()
self.setFocus()
return True
if isinstance(event, QKeyEvent):
key = event.key()
if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]:
if self.popup.currentItem() is not None:
self.on_select_item()
return True
if key == Qt.Key.Key_Escape:
self.setFocus()
self.popup.hide()
return True
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
Qt.Key.Key_PageDown]:
return False
self.setFocus()
self.event(event)
self.popup.hide()
return False

View file

@ -0,0 +1,31 @@
from typing import Optional
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QComboBox, QWidget
from buzz.locale import _
from buzz.transcriber import LANGUAGES
class LanguagesComboBox(QComboBox):
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
languageChanged = pyqtSignal(str)
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
super().__init__(parent)
whisper_languages = sorted(
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1])
self.languages = [('', _('Detect Language'))] + whisper_languages
self.addItems([lang[1] for lang in self.languages])
self.currentIndexChanged.connect(self.on_index_changed)
default_language_key = default_language if default_language != '' else None
for i, lang in enumerate(self.languages):
if lang[0] == default_language_key:
self.setCurrentIndex(i)
def on_index_changed(self, index: int):
self.languageChanged.emit(self.languages[index][0])

View file

@ -0,0 +1,21 @@
from typing import Optional
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QComboBox, QWidget
from buzz.transcriber import Task
class TasksComboBox(QComboBox):
"""TasksComboBox displays a list of tasks available to use with Whisper"""
taskChanged = pyqtSignal(Task)
def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None:
super().__init__(parent, *args)
self.tasks = [i for i in Task]
self.addItems(map(lambda task: task.value.title(), self.tasks))
self.currentIndexChanged.connect(self.on_index_changed)
self.setCurrentText(default_task.value.title())
def on_index_changed(self, index: int):
self.taskChanged.emit(self.tasks[index])

View file

@ -0,0 +1,19 @@
from typing import Optional, Tuple
from PyQt6.QtCore import QObject
from PyQt6.QtGui import QValidator
class TemperatureValidator(QValidator):
def __init__(self, parent: Optional[QObject] = ...) -> None:
super().__init__(parent)
def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]:
try:
temp_strings = [temp.strip() for temp in text.split(',')]
if temp_strings[-1] == '':
return QValidator.State.Intermediate, text, cursor_position
_ = [float(temp) for temp in temp_strings]
return QValidator.State.Acceptable, text, cursor_position
except ValueError:
return QValidator.State.Invalid, text, cursor_position

View file

@ -0,0 +1,140 @@
from typing import Optional, List, Tuple
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox
from buzz.locale import _
from buzz.model_loader import ModelType, WhisperModelSize
from buzz.transcriber import TranscriptionOptions, Task
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
from buzz.widgets.transcriber.advanced_settings_button import AdvancedSettingsButton
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
HuggingFaceSearchLineEdit
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox
class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(
self,
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
model_types: Optional[List[ModelType]] = None,
parent: Optional[QWidget] = None):
super().__init__(title='', parent=parent)
self.transcription_options = default_transcription_options
self.form_layout = QFormLayout(self)
self.tasks_combo_box = TasksComboBox(
default_task=self.transcription_options.task,
parent=self)
self.tasks_combo_box.taskChanged.connect(self.on_task_changed)
self.languages_combo_box = LanguagesComboBox(
default_language=self.transcription_options.language,
parent=self)
self.languages_combo_box.languageChanged.connect(
self.on_language_changed)
self.advanced_settings_button = AdvancedSettingsButton(self)
self.advanced_settings_button.clicked.connect(
self.open_advanced_settings)
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(
self.on_hugging_face_model_changed)
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
default_model=default_transcription_options.model.model_type,
parent=self)
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems(
[size.value.title() for size in WhisperModelSize])
if default_transcription_options.model.whisper_model_size is not None:
self.whisper_model_size_combo_box.setCurrentText(
default_transcription_options.model.whisper_model_size.value.title())
self.whisper_model_size_combo_box.currentTextChanged.connect(
self.on_whisper_model_size_changed)
self.openai_access_token_edit = OpenAIAPIKeyLineEdit(
key=default_transcription_options.openai_access_token,
parent=self)
self.openai_access_token_edit.key_changed.connect(
self.on_openai_access_token_edit_changed)
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
self.form_layout.addRow('', self.whisper_model_size_combo_box)
self.form_layout.addRow('', self.hugging_face_search_line_edit)
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
self.reset_visible_rows()
self.form_layout.addRow('', self.advanced_settings_button)
self.setLayout(self.form_layout)
def on_openai_access_token_edit_changed(self, access_token: str):
self.transcription_options.openai_access_token = access_token
self.transcription_options_changed.emit(self.transcription_options)
def on_language_changed(self, language: str):
self.transcription_options.language = language
self.transcription_options_changed.emit(self.transcription_options)
def on_task_changed(self, task: Task):
self.transcription_options.task = task
self.transcription_options_changed.emit(self.transcription_options)
def on_temperature_changed(self, temperature: Tuple[float, ...]):
self.transcription_options.temperature = temperature
self.transcription_options_changed.emit(self.transcription_options)
def on_initial_prompt_changed(self, initial_prompt: str):
self.transcription_options.initial_prompt = initial_prompt
self.transcription_options_changed.emit(self.transcription_options)
def open_advanced_settings(self):
dialog = AdvancedSettingsDialog(
transcription_options=self.transcription_options, parent=self)
dialog.transcription_options_changed.connect(
self.on_transcription_options_changed)
dialog.exec()
def on_transcription_options_changed(self,
transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
self.transcription_options_changed.emit(transcription_options)
def reset_visible_rows(self):
model_type = self.transcription_options.model.model_type
self.form_layout.setRowVisible(self.hugging_face_search_line_edit,
model_type == ModelType.HUGGING_FACE)
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER) or (
model_type == ModelType.WHISPER_CPP) or (
model_type == ModelType.FASTER_WHISPER))
self.form_layout.setRowVisible(self.openai_access_token_edit,
model_type == ModelType.OPEN_AI_WHISPER_API)
def on_model_type_changed(self, model_type: ModelType):
self.transcription_options.model.model_type = model_type
self.reset_visible_rows()
self.transcription_options_changed.emit(self.transcription_options)
def on_whisper_model_size_changed(self, text: str):
model_size = WhisperModelSize(text.lower())
self.transcription_options.model.whisper_model_size = model_size
self.transcription_options_changed.emit(self.transcription_options)
def on_hugging_face_model_changed(self, model: str):
self.transcription_options.model.hugging_face_model_id = model
self.transcription_options_changed.emit(self.transcription_options)

View file

@ -134,10 +134,8 @@ class TranscriptionViewerWidget(QWidget):
def on_menu_triggered(self, action: QAction):
output_format = OutputFormat[action.text()]
default_path = get_default_output_file_path(
task=self.transcription_task.transcription_options.task,
input_file_path=self.transcription_task.file_path,
output_format=output_format)
default_path = get_default_output_file_path(task=self.transcription_task,
output_format=output_format)
(output_file_path, nil) = QFileDialog.getSaveFileName(self, _('Save File'),
default_path,

View file

@ -14,16 +14,21 @@ from pytestqt.qtbot import QtBot
from buzz.__version__ import VERSION
from buzz.cache import TasksCache
from buzz.gui import (AdvancedSettingsDialog, AudioDevicesComboBox, FileTranscriberWidget,
LanguagesComboBox, MainWindow,
RecordingTranscriberWidget,
TemperatureValidator, HuggingFaceSearchLineEdit,
TranscriptionOptionsGroupBox)
from buzz.gui import (AudioDevicesComboBox, MainWindow,
RecordingTranscriberWidget)
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
HuggingFaceSearchLineEdit
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
from buzz.widgets.about_dialog import AboutDialog
from buzz.model_loader import ModelType
from buzz.settings.settings import Settings
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
TranscriptionOptions)
from buzz.widgets.transcriber.transcription_options_group_box import \
TranscriptionOptionsGroupBox
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
from tests.mock_sounddevice import MockInputStream, mock_query_devices
from .mock_qt import MockNetworkAccessManager, MockNetworkReply
@ -278,32 +283,6 @@ def clear_settings():
settings.clear()
class TestFileTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
qtbot.add_widget(widget)
assert widget.windowTitle() == 'whisper-french.mp3'
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
qtbot.add_widget(widget)
mock_triggered = Mock()
widget.triggered.connect(mock_triggered)
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
0][0]
assert transcription_options.language is None
assert file_transcription_options.file_paths == [
'testdata/whisper-french.mp3']
assert len(model_path) > 0
class TestAboutDialog:
def test_should_check_for_updates(self, qtbot: QtBot):
reply = MockNetworkReply(data={'name': 'v' + VERSION})

View file

@ -89,14 +89,39 @@ class TestWhisperCppFileTranscriber:
class TestWhisperFileTranscriber:
@pytest.mark.parametrize(
'output_format,expected_file_path,default_output_file_name',
[
(OutputFormat.SRT, '/a/b/c-translate--Whisper-tiny.srt', '{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}'),
])
def test_default_output_file2(self, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str):
file_path = get_default_output_file_path(
task=FileTranscriptionTask(
file_path='/a/b/c.mp4',
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name=default_output_file_name),
model_path=''),
output_format=output_format)
assert file_path == expected_file_path
def test_default_output_file(self):
srt = get_default_output_file_path(
Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.TXT)
task=FileTranscriptionTask(
file_path='/a/b/c.mp4',
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
model_path=''),
output_format=OutputFormat.TXT)
assert srt.startswith('/a/b/c (Translated on ')
assert srt.endswith('.txt')
srt = get_default_output_file_path(
Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.SRT)
task=FileTranscriptionTask(
file_path='/a/b/c.mp4',
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
model_path=''),
output_format=OutputFormat.SRT)
assert srt.startswith('/a/b/c (Translated on ')
assert srt.endswith('.srt')

View file

@ -0,0 +1,32 @@
from unittest.mock import Mock
from PyQt6.QtCore import Qt
from pytestqt.qtbot import QtBot
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
class TestFileTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
qtbot.add_widget(widget)
assert widget.windowTitle() == 'whisper-french.mp3'
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
qtbot.add_widget(widget)
mock_triggered = Mock()
widget.triggered.connect(mock_triggered)
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
0][0]
assert transcription_options.language is None
assert file_transcription_options.file_paths == [
'testdata/whisper-french.mp3']
assert len(model_path) > 0

View file

@ -10,7 +10,8 @@ from buzz.widgets.preferences_dialog.general_preferences_widget import \
class TestGeneralPreferencesWidget:
def test_should_disable_test_button_if_no_api_key(self, qtbot):
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''))
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''),
default_export_file_name='')
qtbot.add_widget(widget)
test_button = widget.findChild(QPushButton)
@ -26,7 +27,9 @@ class TestGeneralPreferencesWidget:
assert test_button.isEnabled()
def test_should_test_openai_api_key(self, qtbot):
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store('wrong-api-key'))
widget = GeneralPreferencesWidget(
keyring_store=self.get_keyring_store('wrong-api-key'),
default_export_file_name='')
qtbot.add_widget(widget)
test_button = widget.findChild(QPushButton)

View file

@ -6,7 +6,7 @@ from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog
class TestPreferencesDialog:
def test_create(self, qtbot: QtBot):
dialog = PreferencesDialog(shortcuts={})
dialog = PreferencesDialog(shortcuts={}, default_export_file_name='')
qtbot.add_widget(dialog)
assert dialog.windowTitle() == 'Preferences'