mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 14:45:46 +01:00
Add default file name setting (#559)
This commit is contained in:
parent
64b15f1804
commit
f5f77b3908
23 changed files with 843 additions and 620 deletions
10
buzz/dialogs.py
Normal file
10
buzz/dialogs.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from PyQt6.QtWidgets import QWidget, QMessageBox
|
||||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = parent.tr(
|
||||
'An error occurred while loading the Whisper model') + \
|
||||
f": {error}{'' if error.endswith('.') else '.'}" + \
|
||||
parent.tr("Please retry or check the application logs for more information.")
|
||||
|
||||
QMessageBox.critical(parent, '', message)
|
||||
563
buzz/gui.py
563
buzz/gui.py
|
|
@ -1,51 +1,44 @@
|
|||
import enum
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from enum import auto
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import sounddevice
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (QObject, Qt, QThread,
|
||||
QTimer, QUrl, pyqtSignal, QModelIndex, QPoint,
|
||||
QUrlQuery, QMetaObject, QEvent, QThreadPool)
|
||||
from PyQt6.QtCore import (Qt, QThread,
|
||||
pyqtSignal, QModelIndex, QThreadPool)
|
||||
from PyQt6.QtGui import (QCloseEvent, QIcon,
|
||||
QKeySequence, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
|
||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||
QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox,
|
||||
QFormLayout,
|
||||
QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
|
||||
QKeySequence, QTextCursor, QPainter, QColor)
|
||||
from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QFormLayout,
|
||||
QSizePolicy)
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from .__version__ import VERSION
|
||||
from .action import Action
|
||||
from .assets import get_asset_path
|
||||
from .dialogs import show_model_download_error_dialog
|
||||
from .widgets.icon import Icon, BUZZ_ICON_PATH
|
||||
from .locale import _
|
||||
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \
|
||||
ModelDownloader
|
||||
from .paths import file_paths_as_title
|
||||
from .recording import RecordingAmplitudeListener
|
||||
from .settings.settings import Settings, APP_NAME
|
||||
from .settings.shortcut import Shortcut
|
||||
from .settings.shortcut_settings import ShortcutSettings
|
||||
from .store.keyring_store import KeyringStore
|
||||
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
|
||||
Task,
|
||||
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Task,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionTask, LOADED_WHISPER_DLL,
|
||||
DEFAULT_WHISPER_TEMPERATURE, LANGUAGES)
|
||||
DEFAULT_WHISPER_TEMPERATURE)
|
||||
from .recording_transcriber import RecordingTranscriber
|
||||
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
|
||||
from .widgets.line_edit import LineEdit
|
||||
from .widgets.menu_bar import MenuBar
|
||||
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from .widgets.model_type_combo_box import ModelTypeComboBox
|
||||
from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
from .widgets.toolbar import ToolBar
|
||||
from .widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
from .widgets.transcriber.transcription_options_group_box import \
|
||||
TranscriptionOptionsGroupBox
|
||||
from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
|
||||
from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
|
||||
|
|
@ -102,45 +95,6 @@ class AudioDevicesComboBox(QComboBox):
|
|||
return -1
|
||||
|
||||
|
||||
class LanguagesComboBox(QComboBox):
|
||||
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
|
||||
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
|
||||
languageChanged = pyqtSignal(str)
|
||||
|
||||
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
whisper_languages = sorted(
|
||||
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1])
|
||||
self.languages = [('', _('Detect Language'))] + whisper_languages
|
||||
|
||||
self.addItems([lang[1] for lang in self.languages])
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
|
||||
default_language_key = default_language if default_language != '' else None
|
||||
for i, lang in enumerate(self.languages):
|
||||
if lang[0] == default_language_key:
|
||||
self.setCurrentIndex(i)
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
self.languageChanged.emit(self.languages[index][0])
|
||||
|
||||
|
||||
class TasksComboBox(QComboBox):
|
||||
"""TasksComboBox displays a list of tasks available to use with Whisper"""
|
||||
taskChanged = pyqtSignal(Task)
|
||||
|
||||
def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(parent, *args)
|
||||
self.tasks = [i for i in Task]
|
||||
self.addItems(map(lambda task: task.value.title(), self.tasks))
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
self.setCurrentText(default_task.value.title())
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
self.taskChanged.emit(self.tasks[index])
|
||||
|
||||
|
||||
class TextDisplayBox(QPlainTextEdit):
|
||||
"""TextDisplayBox is a read-only textbox"""
|
||||
|
||||
|
|
@ -164,181 +118,6 @@ class RecordButton(QPushButton):
|
|||
self.setDefault(False)
|
||||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = parent.tr(
|
||||
'An error occurred while loading the Whisper model') + \
|
||||
f": {error}{'' if error.endswith('.') else '.'}" + \
|
||||
parent.tr("Please retry or check the application logs for more information.")
|
||||
|
||||
QMessageBox.critical(parent, '', message)
|
||||
|
||||
|
||||
class FileTranscriberWidget(QWidget):
|
||||
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
|
||||
model_loader: Optional[ModelDownloader] = None
|
||||
file_transcription_options: FileTranscriptionOptions
|
||||
transcription_options: TranscriptionOptions
|
||||
is_transcribing = False
|
||||
# (TranscriptionOptions, FileTranscriptionOptions, str)
|
||||
triggered = pyqtSignal(tuple)
|
||||
openai_access_token_changed = pyqtSignal(str)
|
||||
settings = Settings()
|
||||
|
||||
def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
|
||||
super().__init__(parent, flags)
|
||||
|
||||
self.setWindowTitle(file_paths_as_title(file_paths))
|
||||
|
||||
openai_access_token = KeyringStore().get_password(KeyringStore.Key.OPENAI_API_KEY)
|
||||
|
||||
self.file_paths = file_paths
|
||||
default_language = self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='')
|
||||
self.transcription_options = TranscriptionOptions(
|
||||
openai_access_token=openai_access_token,
|
||||
model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL, default_value=TranscriptionModel()),
|
||||
task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE),
|
||||
language=default_language if default_language != '' else None,
|
||||
initial_prompt=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
|
||||
temperature=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
|
||||
default_value=DEFAULT_WHISPER_TEMPERATURE),
|
||||
word_level_timings=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
default_value=False))
|
||||
default_export_format_states: List[str] = self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
default_value=[])
|
||||
self.file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=self.file_paths,
|
||||
output_formats=set([OutputFormat(output_format) for output_format in default_export_format_states]))
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
transcription_options_group_box = TranscriptionOptionsGroupBox(
|
||||
default_transcription_options=self.transcription_options, parent=self)
|
||||
transcription_options_group_box.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
|
||||
self.word_level_timings_checkbox = QCheckBox(_('Word-level timings'))
|
||||
self.word_level_timings_checkbox.setChecked(
|
||||
self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, default_value=False))
|
||||
self.word_level_timings_checkbox.stateChanged.connect(
|
||||
self.on_word_level_timings_changed)
|
||||
|
||||
file_transcription_layout = QFormLayout()
|
||||
file_transcription_layout.addRow('', self.word_level_timings_checkbox)
|
||||
|
||||
export_format_layout = QHBoxLayout()
|
||||
for output_format in OutputFormat:
|
||||
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}', parent=self)
|
||||
export_format_checkbox.setChecked(output_format in self.file_transcription_options.output_formats)
|
||||
export_format_checkbox.stateChanged.connect(self.get_on_checkbox_state_changed_callback(output_format))
|
||||
export_format_layout.addWidget(export_format_checkbox)
|
||||
|
||||
file_transcription_layout.addRow('Export:', export_format_layout)
|
||||
|
||||
self.run_button = QPushButton(_('Run'), self)
|
||||
self.run_button.setDefault(True)
|
||||
self.run_button.clicked.connect(self.on_click_run)
|
||||
|
||||
layout.addWidget(transcription_options_group_box)
|
||||
layout.addLayout(file_transcription_layout)
|
||||
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
|
||||
def on_checkbox_state_changed(state: int):
|
||||
if state == Qt.CheckState.Checked.value:
|
||||
self.file_transcription_options.output_formats.add(output_format)
|
||||
elif state == Qt.CheckState.Unchecked.value:
|
||||
self.file_transcription_options.output_formats.remove(output_format)
|
||||
|
||||
return on_checkbox_state_changed
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
self.word_level_timings_checkbox.setDisabled(
|
||||
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
if self.transcription_options.openai_access_token != '':
|
||||
self.openai_access_token_changed.emit(self.transcription_options.openai_access_token)
|
||||
|
||||
def on_click_run(self):
|
||||
self.run_button.setDisabled(True)
|
||||
|
||||
model_path = self.transcription_options.model.get_local_model_path()
|
||||
if model_path is not None:
|
||||
self.on_model_loaded(model_path)
|
||||
return
|
||||
|
||||
self.model_loader = ModelDownloader(model=self.transcription_options.model)
|
||||
self.model_loader.signals.progress.connect(self.on_download_model_progress)
|
||||
self.model_loader.signals.error.connect(self.on_download_model_error)
|
||||
self.model_loader.signals.finished.connect(self.on_model_loaded)
|
||||
QThreadPool().globalInstance().start(self.model_loader)
|
||||
|
||||
def on_model_loaded(self, model_path: str):
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
self.triggered.emit((self.transcription_options,
|
||||
self.file_transcription_options, model_path))
|
||||
self.close()
|
||||
|
||||
def on_download_model_progress(self, progress: Tuple[float, float]):
|
||||
(current_size, total_size) = progress
|
||||
|
||||
if self.model_download_progress_dialog is None:
|
||||
self.model_download_progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.transcription_options.model.model_type, parent=self)
|
||||
self.model_download_progress_dialog.canceled.connect(
|
||||
self.on_cancel_model_progress_dialog)
|
||||
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
show_model_download_error_dialog(self, error)
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
def reset_transcriber_controls(self):
|
||||
self.run_button.setDisabled(False)
|
||||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
self.reset_model_download()
|
||||
|
||||
def reset_model_download(self):
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
def on_word_level_timings_changed(self, value: int):
|
||||
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
|
||||
|
||||
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, self.transcription_options.temperature)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, self.transcription_options.initial_prompt)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model)
|
||||
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
value=self.transcription_options.word_level_timings)
|
||||
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
value=[export_format.value for export_format in
|
||||
self.file_transcription_options.output_formats])
|
||||
|
||||
super().closeEvent(event)
|
||||
|
||||
|
||||
class AdvancedSettingsButton(QPushButton):
|
||||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__('Advanced...', parent)
|
||||
|
||||
|
||||
class AudioMeterWidget(QWidget):
|
||||
current_amplitude: float
|
||||
BAR_WIDTH = 2
|
||||
|
|
@ -730,6 +509,9 @@ class MainWindow(QMainWindow):
|
|||
|
||||
self.shortcut_settings = ShortcutSettings(settings=self.settings)
|
||||
self.shortcuts = self.shortcut_settings.load()
|
||||
self.default_export_file_name = self.settings.value(
|
||||
Settings.Key.DEFAULT_EXPORT_FILE_NAME,
|
||||
'{{ input_file_name }} ({{ task }}d on {{ date_time }})')
|
||||
|
||||
self.tasks = {}
|
||||
self.tasks_changed.connect(self.on_tasks_changed)
|
||||
|
|
@ -742,11 +524,14 @@ class MainWindow(QMainWindow):
|
|||
self.addToolBar(self.toolbar)
|
||||
self.setUnifiedTitleAndToolBarOnMac(True)
|
||||
|
||||
self.menu_bar = MenuBar(shortcuts=self.shortcuts, parent=self)
|
||||
self.menu_bar = MenuBar(shortcuts=self.shortcuts,
|
||||
default_export_file_name=self.default_export_file_name,
|
||||
parent=self)
|
||||
self.menu_bar.import_action_triggered.connect(
|
||||
self.on_new_transcription_action_triggered)
|
||||
self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed)
|
||||
self.menu_bar.openai_api_key_changed.connect(self.on_openai_access_token_changed)
|
||||
self.menu_bar.default_export_file_name_changed.connect(self.default_export_file_name_changed)
|
||||
self.setMenuBar(self.menu_bar)
|
||||
|
||||
self.table_widget = TranscriptionTasksTableWidget(self)
|
||||
|
|
@ -840,6 +625,7 @@ class MainWindow(QMainWindow):
|
|||
|
||||
def open_file_transcriber_widget(self, file_paths: List[str]):
|
||||
file_transcriber_window = FileTranscriberWidget(file_paths=file_paths,
|
||||
default_output_file_name=self.default_export_file_name,
|
||||
parent=self,
|
||||
flags=Qt.WindowType.Window)
|
||||
file_transcriber_window.triggered.connect(
|
||||
|
|
@ -851,6 +637,10 @@ class MainWindow(QMainWindow):
|
|||
def on_openai_access_token_changed(access_token: str):
|
||||
KeyringStore().set_password(KeyringStore.Key.OPENAI_API_KEY, access_token)
|
||||
|
||||
def default_export_file_name_changed(self, default_export_file_name: str):
|
||||
self.default_export_file_name = default_export_file_name
|
||||
self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name)
|
||||
|
||||
def open_transcript_viewer(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
for selected_row in selected_rows:
|
||||
|
|
@ -933,241 +723,6 @@ class MainWindow(QMainWindow):
|
|||
super().closeEvent(event)
|
||||
|
||||
|
||||
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
|
||||
class HuggingFaceSearchLineEdit(LineEdit):
|
||||
model_selected = pyqtSignal(str)
|
||||
popup: QListWidget
|
||||
|
||||
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__('', parent)
|
||||
|
||||
self.setMinimumWidth(150)
|
||||
self.setPlaceholderText('openai/whisper-tiny')
|
||||
|
||||
self.timer = QTimer(self)
|
||||
self.timer.setSingleShot(True)
|
||||
self.timer.setInterval(250)
|
||||
self.timer.timeout.connect(self.fetch_models)
|
||||
|
||||
# Restart debounce timer each time editor text changes
|
||||
self.textEdited.connect(self.timer.start)
|
||||
self.textEdited.connect(self.on_text_edited)
|
||||
|
||||
if network_access_manager is None:
|
||||
network_access_manager = QNetworkAccessManager(self)
|
||||
|
||||
self.network_manager = network_access_manager
|
||||
self.network_manager.finished.connect(self.on_request_response)
|
||||
|
||||
self.popup = QListWidget()
|
||||
self.popup.setWindowFlags(Qt.WindowType.Popup)
|
||||
self.popup.setFocusProxy(self)
|
||||
self.popup.setMouseTracking(True)
|
||||
self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
|
||||
self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
|
||||
self.popup.installEventFilter(self)
|
||||
self.popup.itemClicked.connect(self.on_select_item)
|
||||
|
||||
def on_text_edited(self, text: str):
|
||||
self.model_selected.emit(text)
|
||||
|
||||
def on_select_item(self):
|
||||
self.popup.hide()
|
||||
self.setFocus()
|
||||
|
||||
item = self.popup.currentItem()
|
||||
self.setText(item.text())
|
||||
QMetaObject.invokeMethod(self, 'returnPressed')
|
||||
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
|
||||
|
||||
def fetch_models(self):
|
||||
text = self.text()
|
||||
if len(text) < 3:
|
||||
return
|
||||
|
||||
url = QUrl("https://huggingface.co/api/models")
|
||||
|
||||
query = QUrlQuery()
|
||||
query.addQueryItem("filter", "whisper")
|
||||
query.addQueryItem("search", text)
|
||||
|
||||
url.setQuery(query)
|
||||
|
||||
return self.network_manager.get(QNetworkRequest(url))
|
||||
|
||||
def on_popup_selected(self):
|
||||
self.timer.stop()
|
||||
|
||||
def on_request_response(self, network_reply: QNetworkReply):
|
||||
if network_reply.error() != QNetworkReply.NetworkError.NoError:
|
||||
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
|
||||
return
|
||||
|
||||
models = json.loads(network_reply.readAll().data())
|
||||
|
||||
self.popup.setUpdatesEnabled(False)
|
||||
self.popup.clear()
|
||||
|
||||
for model in models:
|
||||
model_id = model.get('id')
|
||||
|
||||
item = QListWidgetItem(self.popup)
|
||||
item.setText(model_id)
|
||||
item.setData(Qt.ItemDataRole.UserRole, model_id)
|
||||
|
||||
self.popup.setCurrentItem(self.popup.item(0))
|
||||
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
|
||||
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
|
||||
self.popup.setUpdatesEnabled(True)
|
||||
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
|
||||
self.popup.setFocus()
|
||||
self.popup.show()
|
||||
|
||||
def eventFilter(self, target: QObject, event: QEvent):
|
||||
if hasattr(self, 'popup') is False or target != self.popup:
|
||||
return False
|
||||
|
||||
if event.type() == QEvent.Type.MouseButtonPress:
|
||||
self.popup.hide()
|
||||
self.setFocus()
|
||||
return True
|
||||
|
||||
if isinstance(event, QKeyEvent):
|
||||
key = event.key()
|
||||
if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]:
|
||||
if self.popup.currentItem() is not None:
|
||||
self.on_select_item()
|
||||
return True
|
||||
|
||||
if key == Qt.Key.Key_Escape:
|
||||
self.setFocus()
|
||||
self.popup.hide()
|
||||
return True
|
||||
|
||||
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
|
||||
Qt.Key.Key_PageDown]:
|
||||
return False
|
||||
|
||||
self.setFocus()
|
||||
self.event(event)
|
||||
self.popup.hide()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TranscriptionOptionsGroupBox(QGroupBox):
|
||||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
||||
def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
|
||||
model_types: Optional[List[ModelType]] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__(title='', parent=parent)
|
||||
self.transcription_options = default_transcription_options
|
||||
|
||||
self.form_layout = QFormLayout(self)
|
||||
|
||||
self.tasks_combo_box = TasksComboBox(
|
||||
default_task=self.transcription_options.task,
|
||||
parent=self)
|
||||
self.tasks_combo_box.taskChanged.connect(self.on_task_changed)
|
||||
|
||||
self.languages_combo_box = LanguagesComboBox(
|
||||
default_language=self.transcription_options.language,
|
||||
parent=self)
|
||||
self.languages_combo_box.languageChanged.connect(
|
||||
self.on_language_changed)
|
||||
|
||||
self.advanced_settings_button = AdvancedSettingsButton(self)
|
||||
self.advanced_settings_button.clicked.connect(
|
||||
self.open_advanced_settings)
|
||||
|
||||
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
|
||||
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
|
||||
|
||||
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
|
||||
default_model=default_transcription_options.model.model_type,
|
||||
parent=self)
|
||||
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
|
||||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
|
||||
if default_transcription_options.model.whisper_model_size is not None:
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
default_transcription_options.model.whisper_model_size.value.title())
|
||||
self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed)
|
||||
|
||||
self.openai_access_token_edit = OpenAIAPIKeyLineEdit(key=default_transcription_options.openai_access_token,
|
||||
parent=self)
|
||||
self.openai_access_token_edit.key_changed.connect(self.on_openai_access_token_edit_changed)
|
||||
|
||||
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
|
||||
self.form_layout.addRow('', self.whisper_model_size_combo_box)
|
||||
self.form_layout.addRow('', self.hugging_face_search_line_edit)
|
||||
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
|
||||
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
|
||||
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
|
||||
|
||||
self.reset_visible_rows()
|
||||
|
||||
self.form_layout.addRow('', self.advanced_settings_button)
|
||||
|
||||
self.setLayout(self.form_layout)
|
||||
|
||||
def on_openai_access_token_edit_changed(self, access_token: str):
|
||||
self.transcription_options.openai_access_token = access_token
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_language_changed(self, language: str):
|
||||
self.transcription_options.language = language
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_task_changed(self, task: Task):
|
||||
self.transcription_options.task = task
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_temperature_changed(self, temperature: Tuple[float, ...]):
|
||||
self.transcription_options.temperature = temperature
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_initial_prompt_changed(self, initial_prompt: str):
|
||||
self.transcription_options.initial_prompt = initial_prompt
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def open_advanced_settings(self):
|
||||
dialog = AdvancedSettingsDialog(
|
||||
transcription_options=self.transcription_options, parent=self)
|
||||
dialog.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
dialog.exec()
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
self.transcription_options_changed.emit(transcription_options)
|
||||
|
||||
def reset_visible_rows(self):
|
||||
model_type = self.transcription_options.model.model_type
|
||||
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
|
||||
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
|
||||
(model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP) or (
|
||||
model_type == ModelType.FASTER_WHISPER))
|
||||
self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
|
||||
def on_model_type_changed(self, model_type: ModelType):
|
||||
self.transcription_options.model.model_type = model_type
|
||||
self.reset_visible_rows()
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_whisper_model_size_changed(self, text: str):
|
||||
model_size = WhisperModelSize(text.lower())
|
||||
self.transcription_options.model.whisper_model_size = model_size
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_hugging_face_model_changed(self, model: str):
|
||||
self.transcription_options.model.hugging_face_model_id = model
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
|
||||
class Application(QApplication):
|
||||
window: MainWindow
|
||||
|
|
@ -1183,73 +738,3 @@ class Application(QApplication):
|
|||
|
||||
def add_task(self, task: FileTranscriptionTask):
|
||||
self.window.add_task(task)
|
||||
|
||||
|
||||
class AdvancedSettingsDialog(QDialog):
|
||||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
self.setWindowTitle(_('Advanced Settings'))
|
||||
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
QDialogButtonBox.StandardButton.Ok), self)
|
||||
button_box.accepted.connect(self.accept)
|
||||
|
||||
layout = QFormLayout(self)
|
||||
|
||||
default_temperature_text = ', '.join(
|
||||
[str(temp) for temp in transcription_options.temperature])
|
||||
self.temperature_line_edit = LineEdit(default_temperature_text, self)
|
||||
self.temperature_line_edit.setPlaceholderText(
|
||||
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"'))
|
||||
self.temperature_line_edit.setMinimumWidth(170)
|
||||
self.temperature_line_edit.textChanged.connect(
|
||||
self.on_temperature_changed)
|
||||
self.temperature_line_edit.setValidator(TemperatureValidator(self))
|
||||
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
|
||||
|
||||
self.initial_prompt_text_edit = QPlainTextEdit(
|
||||
transcription_options.initial_prompt, self)
|
||||
self.initial_prompt_text_edit.textChanged.connect(
|
||||
self.on_initial_prompt_changed)
|
||||
self.initial_prompt_text_edit.setEnabled(
|
||||
transcription_options.model.model_type == ModelType.WHISPER)
|
||||
|
||||
layout.addRow(_('Temperature:'), self.temperature_line_edit)
|
||||
layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit)
|
||||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def on_temperature_changed(self, text: str):
|
||||
try:
|
||||
temperatures = [float(temp.strip()) for temp in text.split(',')]
|
||||
self.transcription_options.temperature = tuple(temperatures)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def on_initial_prompt_changed(self):
|
||||
self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText()
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
|
||||
class TemperatureValidator(QValidator):
|
||||
def __init__(self, parent: Optional[QObject] = ...) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]:
|
||||
try:
|
||||
temp_strings = [temp.strip() for temp in text.split(',')]
|
||||
if temp_strings[-1] == '':
|
||||
return QValidator.State.Intermediate, text, cursor_position
|
||||
_ = [float(temp) for temp in temp_strings]
|
||||
return QValidator.State.Acceptable, text, cursor_position
|
||||
except ValueError:
|
||||
return QValidator.State.Invalid, text, cursor_position
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from platformdirs import user_cache_dir
|
|||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class WhisperModelSize(enum.Enum):
|
||||
class WhisperModelSize(str, enum.Enum):
|
||||
TINY = 'tiny'
|
||||
BASE = 'base'
|
||||
SMALL = 'small'
|
||||
|
|
|
|||
|
|
@ -25,14 +25,18 @@ class Settings:
|
|||
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings'
|
||||
FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats'
|
||||
|
||||
DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name'
|
||||
|
||||
SHORTCUTS = 'shortcuts'
|
||||
|
||||
def set_value(self, key: Key, value: typing.Any) -> None:
|
||||
self.settings.setValue(key.value, value)
|
||||
|
||||
def value(self, key: Key, default_value: typing.Any, value_type: typing.Optional[type] = None) -> typing.Any:
|
||||
def value(self, key: Key, default_value: typing.Any,
|
||||
value_type: typing.Optional[type] = None) -> typing.Any:
|
||||
return self.settings.value(key.value, default_value,
|
||||
value_type if value_type is not None else type(default_value))
|
||||
value_type if value_type is not None else type(
|
||||
default_value))
|
||||
|
||||
def clear(self):
|
||||
self.settings.clear()
|
||||
|
|
|
|||
|
|
@ -66,13 +66,15 @@ class TranscriptionOptions:
|
|||
word_level_timings: bool = False
|
||||
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
|
||||
initial_prompt: str = ''
|
||||
openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS))
|
||||
openai_access_token: str = field(default='',
|
||||
metadata=config(exclude=Exclude.ALWAYS))
|
||||
|
||||
|
||||
@dataclass()
|
||||
class FileTranscriptionOptions:
|
||||
file_paths: List[str]
|
||||
output_formats: Set['OutputFormat'] = field(default_factory=set)
|
||||
default_output_file_name: str = ''
|
||||
|
||||
|
||||
@dataclass_json
|
||||
|
|
@ -127,12 +129,11 @@ class FileTranscriber(QObject):
|
|||
self.completed.emit(segments)
|
||||
|
||||
for output_format in self.transcription_task.file_transcription_options.output_formats:
|
||||
default_path = get_default_output_file_path(
|
||||
task=self.transcription_task.transcription_options.task,
|
||||
input_file_path=self.transcription_task.file_path,
|
||||
output_format=output_format)
|
||||
default_path = get_default_output_file_path(task=self.transcription_task,
|
||||
output_format=output_format)
|
||||
|
||||
write_output(path=default_path, segments=segments, output_format=output_format)
|
||||
write_output(path=default_path, segments=segments,
|
||||
output_format=output_format)
|
||||
|
||||
@abstractmethod
|
||||
def transcribe(self) -> List[Segment]:
|
||||
|
|
@ -172,17 +173,22 @@ class WhisperCppFileTranscriber(FileTranscriber):
|
|||
logging.debug(
|
||||
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, '
|
||||
'word level timings = %s',
|
||||
self.file_path, self.language, self.task, model_path, self.word_level_timings)
|
||||
self.file_path, self.language, self.task, model_path,
|
||||
self.word_level_timings)
|
||||
|
||||
audio = whisper.audio.load_audio(self.file_path)
|
||||
self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE
|
||||
|
||||
whisper_params = whisper_cpp_params(language=self.language if self.language is not None else '', task=self.task,
|
||||
word_level_timings=self.word_level_timings)
|
||||
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(id(self.state))
|
||||
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback)
|
||||
whisper_params = whisper_cpp_params(
|
||||
language=self.language if self.language is not None else '', task=self.task,
|
||||
word_level_timings=self.word_level_timings)
|
||||
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(
|
||||
id(self.state))
|
||||
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(
|
||||
self.encoder_begin_callback)
|
||||
whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state))
|
||||
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(self.new_segment_callback)
|
||||
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(
|
||||
self.new_segment_callback)
|
||||
|
||||
model = WhisperCpp(model=model_path)
|
||||
result = model.transcribe(audio=self.file_path, params=whisper_params)
|
||||
|
|
@ -199,13 +205,15 @@ class WhisperCppFileTranscriber(FileTranscriber):
|
|||
# t1 seems to sometimes be larger than the duration when the
|
||||
# audio ends in silence. Trim to fix the displayed progress.
|
||||
progress = min(t1 * 10, self.duration_audio_ms)
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
|
||||
ctypes.py_object).value
|
||||
if state.running:
|
||||
self.progress.emit((progress, self.duration_audio_ms))
|
||||
|
||||
@staticmethod
|
||||
def encoder_begin_callback(_ctx, _state, user_data):
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
|
||||
ctypes.py_object).value
|
||||
return state.running == 1
|
||||
|
||||
def stop(self):
|
||||
|
|
@ -219,8 +227,10 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
self.task = task.transcription_options.task
|
||||
|
||||
def transcribe(self) -> List[Segment]:
|
||||
logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
|
||||
self.task)
|
||||
logging.debug(
|
||||
'Starting OpenAI Whisper API file transcription, file path = %s, task = %s',
|
||||
self.file_path,
|
||||
self.task)
|
||||
|
||||
wav_file = tempfile.mktemp() + '.wav'
|
||||
(
|
||||
|
|
@ -235,14 +245,18 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
language = self.transcription_task.transcription_options.language
|
||||
response_format = "verbose_json"
|
||||
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
|
||||
transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
|
||||
transcript = openai.Audio.translate("whisper-1", audio_file,
|
||||
response_format=response_format,
|
||||
language=language)
|
||||
else:
|
||||
transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
|
||||
transcript = openai.Audio.transcribe("whisper-1", audio_file,
|
||||
response_format=response_format,
|
||||
language=language)
|
||||
|
||||
segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in
|
||||
transcript["segments"]]
|
||||
segments = [
|
||||
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for
|
||||
segment in
|
||||
transcript["segments"]]
|
||||
return segments
|
||||
|
||||
def stop(self):
|
||||
|
|
@ -273,7 +287,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
self.current_process = multiprocessing.Process(target=self.transcribe_whisper,
|
||||
args=(send_pipe, self.transcription_task))
|
||||
args=(send_pipe,
|
||||
self.transcription_task))
|
||||
if not self.stopped:
|
||||
self.current_process.start()
|
||||
self.started_process = True
|
||||
|
|
@ -291,7 +306,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
|
||||
logging.debug(
|
||||
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
|
||||
self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments))
|
||||
self.current_process.exitcode, datetime.datetime.now() - time_started,
|
||||
len(self.segments))
|
||||
|
||||
if self.current_process.exitcode != 0:
|
||||
raise Exception('Unknown error')
|
||||
|
|
@ -299,7 +315,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
return self.segments
|
||||
|
||||
@classmethod
|
||||
def transcribe_whisper(cls, stderr_conn: Connection, task: FileTranscriptionTask) -> None:
|
||||
def transcribe_whisper(cls, stderr_conn: Connection,
|
||||
task: FileTranscriptionTask) -> None:
|
||||
with pipe_stderr(stderr_conn):
|
||||
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
|
||||
segments = cls.transcribe_hugging_face(task)
|
||||
|
|
@ -308,7 +325,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
elif task.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
segments = cls.transcribe_openai_whisper(task)
|
||||
else:
|
||||
raise Exception(f"Invalid model type: {task.transcription_options.model.model_type}")
|
||||
raise Exception(
|
||||
f"Invalid model type: {task.transcription_options.model.model_type}")
|
||||
|
||||
segments_json = json.dumps(
|
||||
segments, ensure_ascii=True, default=vars)
|
||||
|
|
@ -321,7 +339,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
model = transformers_whisper.load_model(task.model_path)
|
||||
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
|
||||
result = model.transcribe(audio=task.file_path, language=language,
|
||||
task=task.transcription_options.task.value, verbose=False)
|
||||
task=task.transcription_options.task.value,
|
||||
verbose=False)
|
||||
return [
|
||||
Segment(
|
||||
start=int(segment.get('start') * 1000),
|
||||
|
|
@ -368,7 +387,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
stable_whisper.modify_model(model)
|
||||
result = model.transcribe(
|
||||
audio=task.file_path, language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value, temperature=task.transcription_options.temperature,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
|
||||
segments = stable_whisper.group_word_timestamps(result)
|
||||
return [Segment(
|
||||
|
|
@ -423,7 +443,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
|
||||
def write_output(path: str, segments: List[Segment], output_format: OutputFormat):
|
||||
logging.debug(
|
||||
'Writing transcription output, path = %s, output format = %s, number of segments = %s', path, output_format,
|
||||
'Writing transcription output, path = %s, output format = %s, number of segments = %s',
|
||||
path, output_format,
|
||||
len(segments))
|
||||
|
||||
with open(path, 'w', encoding='utf-8') as file:
|
||||
|
|
@ -473,8 +494,22 @@ SUPPORTED_OUTPUT_FORMATS = 'Audio files (*.mp3 *.wav *.m4a *.ogg);;\
|
|||
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)'
|
||||
|
||||
|
||||
def get_default_output_file_path(task: Task, input_file_path: str, output_format: OutputFormat):
|
||||
return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).{output_format.value}'
|
||||
def get_default_output_file_path(task: FileTranscriptionTask,
|
||||
output_format: OutputFormat):
|
||||
input_file_name = os.path.splitext(task.file_path)[0]
|
||||
date_time_now = datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S')
|
||||
return (task.file_transcription_options.default_output_file_name
|
||||
.replace('{{ input_file_name }}', input_file_name)
|
||||
.replace('{{ task }}', task.transcription_options.task.value)
|
||||
.replace('{{ language }}', task.transcription_options.language or '')
|
||||
.replace('{{ model_type }}',
|
||||
task.transcription_options.model.model_type.value)
|
||||
.replace('{{ model_size }}',
|
||||
task.transcription_options.model.whisper_model_size.value if
|
||||
task.transcription_options.model.whisper_model_size is not None else
|
||||
'')
|
||||
.replace('{{ date_time }}', date_time_now)
|
||||
+ f".{output_format.value}")
|
||||
|
||||
|
||||
def whisper_cpp_params(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import webbrowser
|
||||
from typing import Dict
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
|
|
@ -15,11 +16,14 @@ class MenuBar(QMenuBar):
|
|||
import_action_triggered = pyqtSignal()
|
||||
shortcuts_changed = pyqtSignal(dict)
|
||||
openai_api_key_changed = pyqtSignal(str)
|
||||
default_export_file_name_changed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, shortcuts: Dict[str, str], parent: QWidget):
|
||||
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
|
||||
parent: QWidget):
|
||||
super().__init__(parent)
|
||||
|
||||
self.shortcuts = shortcuts
|
||||
self.default_export_file_name = default_export_file_name
|
||||
|
||||
self.import_action = QAction(_("Import Media File..."), self)
|
||||
self.import_action.triggered.connect(
|
||||
|
|
@ -31,6 +35,9 @@ class MenuBar(QMenuBar):
|
|||
self.preferences_action = QAction(_("Preferences..."), self)
|
||||
self.preferences_action.triggered.connect(self.on_preferences_action_triggered)
|
||||
|
||||
help_action = QAction(f'{_("Help")}', self)
|
||||
help_action.triggered.connect(self.on_help_action_triggered)
|
||||
|
||||
self.set_shortcuts(shortcuts)
|
||||
|
||||
file_menu = self.addMenu(_("File"))
|
||||
|
|
@ -38,6 +45,7 @@ class MenuBar(QMenuBar):
|
|||
|
||||
help_menu = self.addMenu(_("Help"))
|
||||
help_menu.addAction(about_action)
|
||||
help_menu.addAction(help_action)
|
||||
help_menu.addAction(self.preferences_action)
|
||||
|
||||
def on_import_action_triggered(self):
|
||||
|
|
@ -48,11 +56,18 @@ class MenuBar(QMenuBar):
|
|||
about_dialog.open()
|
||||
|
||||
def on_preferences_action_triggered(self):
|
||||
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, parent=self)
|
||||
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts,
|
||||
default_export_file_name=self.default_export_file_name,
|
||||
parent=self)
|
||||
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
|
||||
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
|
||||
preferences_dialog.default_export_file_name_changed.connect(
|
||||
self.default_export_file_name_changed)
|
||||
preferences_dialog.open()
|
||||
|
||||
def on_help_action_triggered(self):
|
||||
webbrowser.open('https://chidiwilliams.github.io/buzz/docs')
|
||||
|
||||
def set_shortcuts(self, shortcuts: Dict[str, str]):
|
||||
self.shortcuts = shortcuts
|
||||
|
||||
|
|
|
|||
|
|
@ -3,33 +3,45 @@ from typing import Optional
|
|||
|
||||
import openai
|
||||
from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool
|
||||
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox
|
||||
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox, QLineEdit
|
||||
from openai.error import AuthenticationError
|
||||
|
||||
from buzz.store.keyring_store import KeyringStore
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
|
||||
|
||||
class GeneralPreferencesWidget(QWidget):
|
||||
openai_api_key_changed = pyqtSignal(str)
|
||||
default_export_file_name_changed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, keyring_store=KeyringStore(), parent: Optional[QWidget] = None):
|
||||
def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(),
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.openai_api_key = keyring_store.get_password(KeyringStore.Key.OPENAI_API_KEY)
|
||||
self.openai_api_key = keyring_store.get_password(
|
||||
KeyringStore.Key.OPENAI_API_KEY)
|
||||
|
||||
layout = QFormLayout(self)
|
||||
|
||||
self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self)
|
||||
self.openai_api_key_line_edit.key_changed.connect(self.on_openai_api_key_changed)
|
||||
self.openai_api_key_line_edit.key_changed.connect(
|
||||
self.on_openai_api_key_changed)
|
||||
|
||||
self.test_openai_api_key_button = QPushButton('Test')
|
||||
self.test_openai_api_key_button.clicked.connect(self.on_click_test_openai_api_key_button)
|
||||
self.test_openai_api_key_button.clicked.connect(
|
||||
self.on_click_test_openai_api_key_button)
|
||||
self.update_test_openai_api_key_button()
|
||||
|
||||
layout.addRow('OpenAI API Key', self.openai_api_key_line_edit)
|
||||
layout.addRow('', self.test_openai_api_key_button)
|
||||
|
||||
default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
|
||||
default_export_file_name_line_edit.textChanged.connect(
|
||||
self.default_export_file_name_changed)
|
||||
default_export_file_name_line_edit.setMinimumWidth(200)
|
||||
layout.addRow('Default export file name', default_export_file_name_line_edit)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_test_openai_api_key_button(self):
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import
|
|||
class PreferencesDialog(QDialog):
|
||||
shortcuts_changed = pyqtSignal(dict)
|
||||
openai_api_key_changed = pyqtSignal(str)
|
||||
default_export_file_name_changed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, shortcuts: Dict[str, str],
|
||||
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
|
||||
parent: Optional[QWidget] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
|
|
@ -25,8 +26,11 @@ class PreferencesDialog(QDialog):
|
|||
layout = QVBoxLayout(self)
|
||||
tab_widget = QTabWidget(self)
|
||||
|
||||
general_tab_widget = GeneralPreferencesWidget(parent=self)
|
||||
general_tab_widget = GeneralPreferencesWidget(
|
||||
default_export_file_name=default_export_file_name, parent=self)
|
||||
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
|
||||
general_tab_widget.default_export_file_name_changed.connect(
|
||||
self.default_export_file_name_changed)
|
||||
tab_widget.addTab(general_tab_widget, _('General'))
|
||||
|
||||
models_tab_widget = ModelsPreferencesWidget(parent=self)
|
||||
|
|
|
|||
0
buzz/widgets/transcriber/__init__.py
Normal file
0
buzz/widgets/transcriber/__init__.py
Normal file
8
buzz/widgets/transcriber/advanced_settings_button.py
Normal file
8
buzz/widgets/transcriber/advanced_settings_button.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6.QtWidgets import QPushButton, QWidget
|
||||
|
||||
|
||||
class AdvancedSettingsButton(QPushButton):
|
||||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__('Advanced...', parent)
|
||||
64
buzz/widgets/transcriber/advanced_settings_dialog.py
Normal file
64
buzz/widgets/transcriber/advanced_settings_dialog.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QDialog, QWidget, QDialogButtonBox, QFormLayout, \
|
||||
QPlainTextEdit
|
||||
|
||||
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.transcriber import TranscriptionOptions
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
|
||||
|
||||
class AdvancedSettingsDialog(QDialog):
|
||||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
self.setWindowTitle(_('Advanced Settings'))
|
||||
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
QDialogButtonBox.StandardButton.Ok), self)
|
||||
button_box.accepted.connect(self.accept)
|
||||
|
||||
layout = QFormLayout(self)
|
||||
|
||||
default_temperature_text = ', '.join(
|
||||
[str(temp) for temp in transcription_options.temperature])
|
||||
self.temperature_line_edit = LineEdit(default_temperature_text, self)
|
||||
self.temperature_line_edit.setPlaceholderText(
|
||||
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"'))
|
||||
self.temperature_line_edit.setMinimumWidth(170)
|
||||
self.temperature_line_edit.textChanged.connect(
|
||||
self.on_temperature_changed)
|
||||
self.temperature_line_edit.setValidator(TemperatureValidator(self))
|
||||
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
|
||||
|
||||
self.initial_prompt_text_edit = QPlainTextEdit(
|
||||
transcription_options.initial_prompt, self)
|
||||
self.initial_prompt_text_edit.textChanged.connect(
|
||||
self.on_initial_prompt_changed)
|
||||
self.initial_prompt_text_edit.setEnabled(
|
||||
transcription_options.model.model_type == ModelType.WHISPER)
|
||||
|
||||
layout.addRow(_('Temperature:'), self.temperature_line_edit)
|
||||
layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit)
|
||||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def on_temperature_changed(self, text: str):
|
||||
try:
|
||||
temperatures = [float(temp.strip()) for temp in text.split(',')]
|
||||
self.transcription_options.temperature = tuple(temperatures)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def on_initial_prompt_changed(self):
|
||||
self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText()
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
204
buzz/widgets/transcriber/file_transcriber_widget.py
Normal file
204
buzz/widgets/transcriber/file_transcriber_widget.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
from typing import Optional, List, Tuple
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import pyqtSignal, Qt, QThreadPool
|
||||
from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout, \
|
||||
QPushButton
|
||||
|
||||
from buzz.dialogs import show_model_download_error_dialog
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType
|
||||
from buzz.paths import file_paths_as_title
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.store.keyring_store import KeyringStore
|
||||
from buzz.transcriber import FileTranscriptionOptions, TranscriptionOptions, Task, \
|
||||
DEFAULT_WHISPER_TEMPERATURE, OutputFormat
|
||||
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from buzz.widgets.transcriber.transcription_options_group_box import \
|
||||
TranscriptionOptionsGroupBox
|
||||
|
||||
|
||||
class FileTranscriberWidget(QWidget):
|
||||
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
|
||||
model_loader: Optional[ModelDownloader] = None
|
||||
file_transcription_options: FileTranscriptionOptions
|
||||
transcription_options: TranscriptionOptions
|
||||
is_transcribing = False
|
||||
# (TranscriptionOptions, FileTranscriptionOptions, str)
|
||||
triggered = pyqtSignal(tuple)
|
||||
openai_access_token_changed = pyqtSignal(str)
|
||||
settings = Settings()
|
||||
|
||||
def __init__(self, file_paths: List[str],
|
||||
default_output_file_name: str,
|
||||
parent: Optional[QWidget] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
|
||||
super().__init__(parent, flags)
|
||||
|
||||
self.setWindowTitle(file_paths_as_title(file_paths))
|
||||
|
||||
openai_access_token = KeyringStore().get_password(
|
||||
KeyringStore.Key.OPENAI_API_KEY)
|
||||
|
||||
self.file_paths = file_paths
|
||||
default_language = self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='')
|
||||
self.transcription_options = TranscriptionOptions(
|
||||
openai_access_token=openai_access_token,
|
||||
model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL,
|
||||
default_value=TranscriptionModel()),
|
||||
task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK,
|
||||
default_value=Task.TRANSCRIBE),
|
||||
language=default_language if default_language != '' else None,
|
||||
initial_prompt=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
|
||||
temperature=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
|
||||
default_value=DEFAULT_WHISPER_TEMPERATURE),
|
||||
word_level_timings=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
default_value=False))
|
||||
default_export_format_states: List[str] = self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
default_value=[])
|
||||
self.file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=self.file_paths,
|
||||
output_formats=set([OutputFormat(output_format) for output_format in
|
||||
default_export_format_states]),
|
||||
default_output_file_name=default_output_file_name)
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
transcription_options_group_box = TranscriptionOptionsGroupBox(
|
||||
default_transcription_options=self.transcription_options, parent=self)
|
||||
transcription_options_group_box.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
|
||||
self.word_level_timings_checkbox = QCheckBox(_('Word-level timings'))
|
||||
self.word_level_timings_checkbox.setChecked(
|
||||
self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
default_value=False))
|
||||
self.word_level_timings_checkbox.stateChanged.connect(
|
||||
self.on_word_level_timings_changed)
|
||||
|
||||
file_transcription_layout = QFormLayout()
|
||||
file_transcription_layout.addRow('', self.word_level_timings_checkbox)
|
||||
|
||||
export_format_layout = QHBoxLayout()
|
||||
for output_format in OutputFormat:
|
||||
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}',
|
||||
parent=self)
|
||||
export_format_checkbox.setChecked(
|
||||
output_format in self.file_transcription_options.output_formats)
|
||||
export_format_checkbox.stateChanged.connect(
|
||||
self.get_on_checkbox_state_changed_callback(output_format))
|
||||
export_format_layout.addWidget(export_format_checkbox)
|
||||
|
||||
file_transcription_layout.addRow('Export:', export_format_layout)
|
||||
|
||||
self.run_button = QPushButton(_('Run'), self)
|
||||
self.run_button.setDefault(True)
|
||||
self.run_button.clicked.connect(self.on_click_run)
|
||||
|
||||
layout.addWidget(transcription_options_group_box)
|
||||
layout.addLayout(file_transcription_layout)
|
||||
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
|
||||
def on_checkbox_state_changed(state: int):
|
||||
if state == Qt.CheckState.Checked.value:
|
||||
self.file_transcription_options.output_formats.add(output_format)
|
||||
elif state == Qt.CheckState.Unchecked.value:
|
||||
self.file_transcription_options.output_formats.remove(output_format)
|
||||
|
||||
return on_checkbox_state_changed
|
||||
|
||||
def on_transcription_options_changed(self,
|
||||
transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
self.word_level_timings_checkbox.setDisabled(
|
||||
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or
|
||||
self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
if self.transcription_options.openai_access_token != '':
|
||||
self.openai_access_token_changed.emit(
|
||||
self.transcription_options.openai_access_token)
|
||||
|
||||
def on_click_run(self):
|
||||
self.run_button.setDisabled(True)
|
||||
|
||||
model_path = self.transcription_options.model.get_local_model_path()
|
||||
if model_path is not None:
|
||||
self.on_model_loaded(model_path)
|
||||
return
|
||||
|
||||
self.model_loader = ModelDownloader(model=self.transcription_options.model)
|
||||
self.model_loader.signals.progress.connect(self.on_download_model_progress)
|
||||
self.model_loader.signals.error.connect(self.on_download_model_error)
|
||||
self.model_loader.signals.finished.connect(self.on_model_loaded)
|
||||
QThreadPool().globalInstance().start(self.model_loader)
|
||||
|
||||
def on_model_loaded(self, model_path: str):
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
self.triggered.emit((self.transcription_options,
|
||||
self.file_transcription_options, model_path))
|
||||
self.close()
|
||||
|
||||
def on_download_model_progress(self, progress: Tuple[float, float]):
|
||||
(current_size, total_size) = progress
|
||||
|
||||
if self.model_download_progress_dialog is None:
|
||||
self.model_download_progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.transcription_options.model.model_type, parent=self)
|
||||
self.model_download_progress_dialog.canceled.connect(
|
||||
self.on_cancel_model_progress_dialog)
|
||||
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_value(
|
||||
fraction_completed=current_size / total_size)
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
show_model_download_error_dialog(self, error)
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
def reset_transcriber_controls(self):
|
||||
self.run_button.setDisabled(False)
|
||||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
self.reset_model_download()
|
||||
|
||||
def reset_model_download(self):
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
def on_word_level_timings_changed(self, value: int):
|
||||
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
|
||||
|
||||
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE,
|
||||
self.transcription_options.language)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK,
|
||||
self.transcription_options.task)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
|
||||
self.transcription_options.temperature)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT,
|
||||
self.transcription_options.initial_prompt)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL,
|
||||
self.transcription_options.model)
|
||||
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
value=self.transcription_options.word_level_timings)
|
||||
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
value=[export_format.value for export_format in
|
||||
self.file_transcription_options.output_formats])
|
||||
|
||||
super().closeEvent(event)
|
||||
134
buzz/widgets/transcriber/hugging_face_search_line_edit.py
Normal file
134
buzz/widgets/transcriber/hugging_face_search_line_edit.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal, QTimer, Qt, QMetaObject, QUrl, QUrlQuery, QPoint, \
|
||||
QObject, QEvent
|
||||
from PyQt6.QtGui import QKeyEvent
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
|
||||
from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem
|
||||
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
|
||||
|
||||
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
|
||||
class HuggingFaceSearchLineEdit(LineEdit):
|
||||
model_selected = pyqtSignal(str)
|
||||
popup: QListWidget
|
||||
|
||||
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__('', parent)
|
||||
|
||||
self.setMinimumWidth(150)
|
||||
self.setPlaceholderText('openai/whisper-tiny')
|
||||
|
||||
self.timer = QTimer(self)
|
||||
self.timer.setSingleShot(True)
|
||||
self.timer.setInterval(250)
|
||||
self.timer.timeout.connect(self.fetch_models)
|
||||
|
||||
# Restart debounce timer each time editor text changes
|
||||
self.textEdited.connect(self.timer.start)
|
||||
self.textEdited.connect(self.on_text_edited)
|
||||
|
||||
if network_access_manager is None:
|
||||
network_access_manager = QNetworkAccessManager(self)
|
||||
|
||||
self.network_manager = network_access_manager
|
||||
self.network_manager.finished.connect(self.on_request_response)
|
||||
|
||||
self.popup = QListWidget()
|
||||
self.popup.setWindowFlags(Qt.WindowType.Popup)
|
||||
self.popup.setFocusProxy(self)
|
||||
self.popup.setMouseTracking(True)
|
||||
self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
|
||||
self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
|
||||
self.popup.installEventFilter(self)
|
||||
self.popup.itemClicked.connect(self.on_select_item)
|
||||
|
||||
def on_text_edited(self, text: str):
|
||||
self.model_selected.emit(text)
|
||||
|
||||
def on_select_item(self):
|
||||
self.popup.hide()
|
||||
self.setFocus()
|
||||
|
||||
item = self.popup.currentItem()
|
||||
self.setText(item.text())
|
||||
QMetaObject.invokeMethod(self, 'returnPressed')
|
||||
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
|
||||
|
||||
def fetch_models(self):
|
||||
text = self.text()
|
||||
if len(text) < 3:
|
||||
return
|
||||
|
||||
url = QUrl("https://huggingface.co/api/models")
|
||||
|
||||
query = QUrlQuery()
|
||||
query.addQueryItem("filter", "whisper")
|
||||
query.addQueryItem("search", text)
|
||||
|
||||
url.setQuery(query)
|
||||
|
||||
return self.network_manager.get(QNetworkRequest(url))
|
||||
|
||||
def on_popup_selected(self):
|
||||
self.timer.stop()
|
||||
|
||||
def on_request_response(self, network_reply: QNetworkReply):
|
||||
if network_reply.error() != QNetworkReply.NetworkError.NoError:
|
||||
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
|
||||
return
|
||||
|
||||
models = json.loads(network_reply.readAll().data())
|
||||
|
||||
self.popup.setUpdatesEnabled(False)
|
||||
self.popup.clear()
|
||||
|
||||
for model in models:
|
||||
model_id = model.get('id')
|
||||
|
||||
item = QListWidgetItem(self.popup)
|
||||
item.setText(model_id)
|
||||
item.setData(Qt.ItemDataRole.UserRole, model_id)
|
||||
|
||||
self.popup.setCurrentItem(self.popup.item(0))
|
||||
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
|
||||
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
|
||||
self.popup.setUpdatesEnabled(True)
|
||||
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
|
||||
self.popup.setFocus()
|
||||
self.popup.show()
|
||||
|
||||
def eventFilter(self, target: QObject, event: QEvent):
|
||||
if hasattr(self, 'popup') is False or target != self.popup:
|
||||
return False
|
||||
|
||||
if event.type() == QEvent.Type.MouseButtonPress:
|
||||
self.popup.hide()
|
||||
self.setFocus()
|
||||
return True
|
||||
|
||||
if isinstance(event, QKeyEvent):
|
||||
key = event.key()
|
||||
if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]:
|
||||
if self.popup.currentItem() is not None:
|
||||
self.on_select_item()
|
||||
return True
|
||||
|
||||
if key == Qt.Key.Key_Escape:
|
||||
self.setFocus()
|
||||
self.popup.hide()
|
||||
return True
|
||||
|
||||
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
|
||||
Qt.Key.Key_PageDown]:
|
||||
return False
|
||||
|
||||
self.setFocus()
|
||||
self.event(event)
|
||||
self.popup.hide()
|
||||
|
||||
return False
|
||||
31
buzz/widgets/transcriber/languages_combo_box.py
Normal file
31
buzz/widgets/transcriber/languages_combo_box.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QComboBox, QWidget
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.transcriber import LANGUAGES
|
||||
|
||||
|
||||
class LanguagesComboBox(QComboBox):
|
||||
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
|
||||
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
|
||||
languageChanged = pyqtSignal(str)
|
||||
|
||||
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
whisper_languages = sorted(
|
||||
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1])
|
||||
self.languages = [('', _('Detect Language'))] + whisper_languages
|
||||
|
||||
self.addItems([lang[1] for lang in self.languages])
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
|
||||
default_language_key = default_language if default_language != '' else None
|
||||
for i, lang in enumerate(self.languages):
|
||||
if lang[0] == default_language_key:
|
||||
self.setCurrentIndex(i)
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
self.languageChanged.emit(self.languages[index][0])
|
||||
21
buzz/widgets/transcriber/tasks_combo_box.py
Normal file
21
buzz/widgets/transcriber/tasks_combo_box.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QComboBox, QWidget
|
||||
|
||||
from buzz.transcriber import Task
|
||||
|
||||
|
||||
class TasksComboBox(QComboBox):
|
||||
"""TasksComboBox displays a list of tasks available to use with Whisper"""
|
||||
taskChanged = pyqtSignal(Task)
|
||||
|
||||
def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(parent, *args)
|
||||
self.tasks = [i for i in Task]
|
||||
self.addItems(map(lambda task: task.value.title(), self.tasks))
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
self.setCurrentText(default_task.value.title())
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
self.taskChanged.emit(self.tasks[index])
|
||||
19
buzz/widgets/transcriber/temperature_validator.py
Normal file
19
buzz/widgets/transcriber/temperature_validator.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from PyQt6.QtCore import QObject
|
||||
from PyQt6.QtGui import QValidator
|
||||
|
||||
|
||||
class TemperatureValidator(QValidator):
|
||||
def __init__(self, parent: Optional[QObject] = ...) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]:
|
||||
try:
|
||||
temp_strings = [temp.strip() for temp in text.split(',')]
|
||||
if temp_strings[-1] == '':
|
||||
return QValidator.State.Intermediate, text, cursor_position
|
||||
_ = [float(temp) for temp in temp_strings]
|
||||
return QValidator.State.Acceptable, text, cursor_position
|
||||
except ValueError:
|
||||
return QValidator.State.Invalid, text, cursor_position
|
||||
140
buzz/widgets/transcriber/transcription_options_group_box.py
Normal file
140
buzz/widgets/transcriber/transcription_options_group_box.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
from typing import Optional, List, Tuple
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType, WhisperModelSize
|
||||
from buzz.transcriber import TranscriptionOptions, Task
|
||||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
from buzz.widgets.transcriber.advanced_settings_button import AdvancedSettingsButton
|
||||
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
|
||||
HuggingFaceSearchLineEdit
|
||||
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
|
||||
from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox
|
||||
|
||||
|
||||
class TranscriptionOptionsGroupBox(QGroupBox):
|
||||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
|
||||
model_types: Optional[List[ModelType]] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__(title='', parent=parent)
|
||||
self.transcription_options = default_transcription_options
|
||||
|
||||
self.form_layout = QFormLayout(self)
|
||||
|
||||
self.tasks_combo_box = TasksComboBox(
|
||||
default_task=self.transcription_options.task,
|
||||
parent=self)
|
||||
self.tasks_combo_box.taskChanged.connect(self.on_task_changed)
|
||||
|
||||
self.languages_combo_box = LanguagesComboBox(
|
||||
default_language=self.transcription_options.language,
|
||||
parent=self)
|
||||
self.languages_combo_box.languageChanged.connect(
|
||||
self.on_language_changed)
|
||||
|
||||
self.advanced_settings_button = AdvancedSettingsButton(self)
|
||||
self.advanced_settings_button.clicked.connect(
|
||||
self.open_advanced_settings)
|
||||
|
||||
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
|
||||
self.hugging_face_search_line_edit.model_selected.connect(
|
||||
self.on_hugging_face_model_changed)
|
||||
|
||||
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
|
||||
default_model=default_transcription_options.model.model_type,
|
||||
parent=self)
|
||||
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
|
||||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems(
|
||||
[size.value.title() for size in WhisperModelSize])
|
||||
if default_transcription_options.model.whisper_model_size is not None:
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
default_transcription_options.model.whisper_model_size.value.title())
|
||||
self.whisper_model_size_combo_box.currentTextChanged.connect(
|
||||
self.on_whisper_model_size_changed)
|
||||
|
||||
self.openai_access_token_edit = OpenAIAPIKeyLineEdit(
|
||||
key=default_transcription_options.openai_access_token,
|
||||
parent=self)
|
||||
self.openai_access_token_edit.key_changed.connect(
|
||||
self.on_openai_access_token_edit_changed)
|
||||
|
||||
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
|
||||
self.form_layout.addRow('', self.whisper_model_size_combo_box)
|
||||
self.form_layout.addRow('', self.hugging_face_search_line_edit)
|
||||
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
|
||||
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
|
||||
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
|
||||
|
||||
self.reset_visible_rows()
|
||||
|
||||
self.form_layout.addRow('', self.advanced_settings_button)
|
||||
|
||||
self.setLayout(self.form_layout)
|
||||
|
||||
def on_openai_access_token_edit_changed(self, access_token: str):
|
||||
self.transcription_options.openai_access_token = access_token
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_language_changed(self, language: str):
|
||||
self.transcription_options.language = language
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_task_changed(self, task: Task):
|
||||
self.transcription_options.task = task
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_temperature_changed(self, temperature: Tuple[float, ...]):
|
||||
self.transcription_options.temperature = temperature
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_initial_prompt_changed(self, initial_prompt: str):
|
||||
self.transcription_options.initial_prompt = initial_prompt
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def open_advanced_settings(self):
|
||||
dialog = AdvancedSettingsDialog(
|
||||
transcription_options=self.transcription_options, parent=self)
|
||||
dialog.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
dialog.exec()
|
||||
|
||||
def on_transcription_options_changed(self,
|
||||
transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
self.transcription_options_changed.emit(transcription_options)
|
||||
|
||||
def reset_visible_rows(self):
|
||||
model_type = self.transcription_options.model.model_type
|
||||
self.form_layout.setRowVisible(self.hugging_face_search_line_edit,
|
||||
model_type == ModelType.HUGGING_FACE)
|
||||
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
|
||||
(model_type == ModelType.WHISPER) or (
|
||||
model_type == ModelType.WHISPER_CPP) or (
|
||||
model_type == ModelType.FASTER_WHISPER))
|
||||
self.form_layout.setRowVisible(self.openai_access_token_edit,
|
||||
model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
|
||||
def on_model_type_changed(self, model_type: ModelType):
|
||||
self.transcription_options.model.model_type = model_type
|
||||
self.reset_visible_rows()
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_whisper_model_size_changed(self, text: str):
|
||||
model_size = WhisperModelSize(text.lower())
|
||||
self.transcription_options.model.whisper_model_size = model_size
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_hugging_face_model_changed(self, model: str):
|
||||
self.transcription_options.model.hugging_face_model_id = model
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
|
@ -134,10 +134,8 @@ class TranscriptionViewerWidget(QWidget):
|
|||
def on_menu_triggered(self, action: QAction):
|
||||
output_format = OutputFormat[action.text()]
|
||||
|
||||
default_path = get_default_output_file_path(
|
||||
task=self.transcription_task.transcription_options.task,
|
||||
input_file_path=self.transcription_task.file_path,
|
||||
output_format=output_format)
|
||||
default_path = get_default_output_file_path(task=self.transcription_task,
|
||||
output_format=output_format)
|
||||
|
||||
(output_file_path, nil) = QFileDialog.getSaveFileName(self, _('Save File'),
|
||||
default_path,
|
||||
|
|
|
|||
|
|
@ -14,16 +14,21 @@ from pytestqt.qtbot import QtBot
|
|||
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.gui import (AdvancedSettingsDialog, AudioDevicesComboBox, FileTranscriberWidget,
|
||||
LanguagesComboBox, MainWindow,
|
||||
RecordingTranscriberWidget,
|
||||
TemperatureValidator, HuggingFaceSearchLineEdit,
|
||||
TranscriptionOptionsGroupBox)
|
||||
from buzz.gui import (AudioDevicesComboBox, MainWindow,
|
||||
RecordingTranscriberWidget)
|
||||
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
|
||||
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
|
||||
HuggingFaceSearchLineEdit
|
||||
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
|
||||
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
|
||||
from buzz.widgets.about_dialog import AboutDialog
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
|
||||
TranscriptionOptions)
|
||||
from buzz.widgets.transcriber.transcription_options_group_box import \
|
||||
TranscriptionOptionsGroupBox
|
||||
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
from tests.mock_sounddevice import MockInputStream, mock_query_devices
|
||||
from .mock_qt import MockNetworkAccessManager, MockNetworkReply
|
||||
|
|
@ -278,32 +283,6 @@ def clear_settings():
|
|||
settings.clear()
|
||||
|
||||
|
||||
class TestFileTranscriberWidget:
|
||||
def test_should_set_window_title(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], parent=None)
|
||||
qtbot.add_widget(widget)
|
||||
assert widget.windowTitle() == 'whisper-french.mp3'
|
||||
|
||||
def test_should_emit_triggered_event(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], parent=None)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_triggered = Mock()
|
||||
widget.triggered.connect(mock_triggered)
|
||||
|
||||
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
|
||||
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
|
||||
|
||||
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
|
||||
0][0]
|
||||
assert transcription_options.language is None
|
||||
assert file_transcription_options.file_paths == [
|
||||
'testdata/whisper-french.mp3']
|
||||
assert len(model_path) > 0
|
||||
|
||||
|
||||
class TestAboutDialog:
|
||||
def test_should_check_for_updates(self, qtbot: QtBot):
|
||||
reply = MockNetworkReply(data={'name': 'v' + VERSION})
|
||||
|
|
|
|||
|
|
@ -89,14 +89,39 @@ class TestWhisperCppFileTranscriber:
|
|||
|
||||
|
||||
class TestWhisperFileTranscriber:
|
||||
@pytest.mark.parametrize(
|
||||
'output_format,expected_file_path,default_output_file_name',
|
||||
[
|
||||
(OutputFormat.SRT, '/a/b/c-translate--Whisper-tiny.srt', '{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}'),
|
||||
])
|
||||
def test_default_output_file2(self, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str):
|
||||
file_path = get_default_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path='/a/b/c.mp4',
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name=default_output_file_name),
|
||||
model_path=''),
|
||||
output_format=output_format)
|
||||
assert file_path == expected_file_path
|
||||
|
||||
def test_default_output_file(self):
|
||||
srt = get_default_output_file_path(
|
||||
Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.TXT)
|
||||
task=FileTranscriptionTask(
|
||||
file_path='/a/b/c.mp4',
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
|
||||
model_path=''),
|
||||
output_format=OutputFormat.TXT)
|
||||
assert srt.startswith('/a/b/c (Translated on ')
|
||||
assert srt.endswith('.txt')
|
||||
|
||||
srt = get_default_output_file_path(
|
||||
Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.SRT)
|
||||
task=FileTranscriptionTask(
|
||||
file_path='/a/b/c.mp4',
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
|
||||
model_path=''),
|
||||
output_format=OutputFormat.SRT)
|
||||
assert srt.startswith('/a/b/c (Translated on ')
|
||||
assert srt.endswith('.srt')
|
||||
|
||||
|
|
|
|||
32
tests/widgets/file_transcriber_widget_test.py
Normal file
32
tests/widgets/file_transcriber_widget_test.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from PyQt6.QtCore import Qt
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
|
||||
|
||||
class TestFileTranscriberWidget:
|
||||
def test_should_set_window_title(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
|
||||
qtbot.add_widget(widget)
|
||||
assert widget.windowTitle() == 'whisper-french.mp3'
|
||||
|
||||
def test_should_emit_triggered_event(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_triggered = Mock()
|
||||
widget.triggered.connect(mock_triggered)
|
||||
|
||||
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
|
||||
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
|
||||
|
||||
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
|
||||
0][0]
|
||||
assert transcription_options.language is None
|
||||
assert file_transcription_options.file_paths == [
|
||||
'testdata/whisper-french.mp3']
|
||||
assert len(model_path) > 0
|
||||
|
|
@ -10,7 +10,8 @@ from buzz.widgets.preferences_dialog.general_preferences_widget import \
|
|||
|
||||
class TestGeneralPreferencesWidget:
|
||||
def test_should_disable_test_button_if_no_api_key(self, qtbot):
|
||||
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''))
|
||||
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''),
|
||||
default_export_file_name='')
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
test_button = widget.findChild(QPushButton)
|
||||
|
|
@ -26,7 +27,9 @@ class TestGeneralPreferencesWidget:
|
|||
assert test_button.isEnabled()
|
||||
|
||||
def test_should_test_openai_api_key(self, qtbot):
|
||||
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store('wrong-api-key'))
|
||||
widget = GeneralPreferencesWidget(
|
||||
keyring_store=self.get_keyring_store('wrong-api-key'),
|
||||
default_export_file_name='')
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
test_button = widget.findChild(QPushButton)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog
|
|||
|
||||
class TestPreferencesDialog:
|
||||
def test_create(self, qtbot: QtBot):
|
||||
dialog = PreferencesDialog(shortcuts={})
|
||||
dialog = PreferencesDialog(shortcuts={}, default_export_file_name='')
|
||||
qtbot.add_widget(dialog)
|
||||
|
||||
assert dialog.windowTitle() == 'Preferences'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue