diff --git a/build.py b/build.py index e7eb59eb..00725f9a 100644 --- a/build.py +++ b/build.py @@ -2,7 +2,7 @@ import subprocess def build(setup_kwargs): - subprocess.call(['make', 'buzz/whisper_cpp.py']) + subprocess.call(["make", "buzz/whisper_cpp.py"]) if __name__ == "__main__": diff --git a/buzz/action.py b/buzz/action.py index e0536670..9a9d8675 100644 --- a/buzz/action.py +++ b/buzz/action.py @@ -4,7 +4,10 @@ from PyQt6.QtGui import QAction, QKeySequence class Action(QAction): - def setShortcut(self, shortcut: typing.Union['QKeySequence', 'QKeySequence.StandardKey', str, int]) -> None: + def setShortcut( + self, + shortcut: typing.Union["QKeySequence", "QKeySequence.StandardKey", str, int], + ) -> None: super().setShortcut(shortcut) self.setToolTip(Action.get_tooltip(self)) diff --git a/buzz/assets.py b/buzz/assets.py index bba32242..05fcf914 100644 --- a/buzz/assets.py +++ b/buzz/assets.py @@ -3,6 +3,6 @@ import sys def get_asset_path(path: str): - if getattr(sys, 'frozen', False): + if getattr(sys, "frozen", False): return os.path.join(os.path.dirname(sys.executable), path) - return os.path.join(os.path.dirname(__file__), '..', path) + return os.path.join(os.path.dirname(__file__), "..", path) diff --git a/buzz/buzz.py b/buzz/buzz.py index b414e8c2..d0f78cfe 100644 --- a/buzz/buzz.py +++ b/buzz/buzz.py @@ -9,7 +9,7 @@ from typing import TextIO from appdirs import user_log_dir # Check for segfaults if not running in frozen mode -if getattr(sys, 'frozen', False) is False: +if getattr(sys, "frozen", False) is False: faulthandler.enable() # Sets stderr to no-op TextIO when None (run as Windows GUI). @@ -19,30 +19,35 @@ if sys.stderr is None: # Adds the current directory to the PATH, so the ffmpeg binary get picked up: # https://stackoverflow.com/a/44352931/9830227 -app_dir = getattr(sys, '_MEIPASS', os.path.dirname( - os.path.abspath(__file__))) +app_dir = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) os.environ["PATH"] += os.pathsep + app_dir # Add the app directory to the DLL list: https://stackoverflow.com/a/64303856 -if platform.system() == 'Windows': +if platform.system() == "Windows": os.add_dll_directory(app_dir) def main(): - if platform.system() == 'Linux': - multiprocessing.set_start_method('spawn') + if platform.system() == "Linux": + multiprocessing.set_start_method("spawn") # Fixes opening new window when app has been frozen on Windows: # https://stackoverflow.com/a/33979091 multiprocessing.freeze_support() - log_dir = user_log_dir(appname='Buzz') + log_dir = user_log_dir(appname="Buzz") os.makedirs(log_dir, exist_ok=True) - log_format = "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s" - logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.DEBUG, format=log_format) + log_format = ( + "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s" + ) + logging.basicConfig( + filename=os.path.join(log_dir, "logs.txt"), + level=logging.DEBUG, + format=log_format, + ) - if getattr(sys, 'frozen', False) is False: + if getattr(sys, "frozen", False) is False: stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setLevel(logging.DEBUG) stdout_handler.setFormatter(logging.Formatter(log_format)) diff --git a/buzz/cache.py b/buzz/cache.py index 26ae35a9..7ca529c2 100644 --- a/buzz/cache.py +++ b/buzz/cache.py @@ -9,11 +9,11 @@ from .transcriber import FileTranscriptionTask class TasksCache: - def __init__(self, cache_dir=user_cache_dir('Buzz')): + def __init__(self, cache_dir=user_cache_dir("Buzz")): os.makedirs(cache_dir, exist_ok=True) self.cache_dir = cache_dir - self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks') - self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json') + self.pickle_cache_file_path = os.path.join(cache_dir, "tasks") + self.tasks_list_file_path = os.path.join(cache_dir, "tasks.json") def save(self, tasks: List[FileTranscriptionTask]): self.save_json_tasks(tasks=tasks) @@ -23,16 +23,20 @@ class TasksCache: return self.load_json_tasks() try: - with open(self.pickle_cache_file_path, 'rb') as file: + with open(self.pickle_cache_file_path, "rb") as file: return pickle.load(file) except FileNotFoundError: return [] - except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache + except ( + pickle.UnpicklingError, + AttributeError, + ValueError, + ): # delete corrupted cache os.remove(self.pickle_cache_file_path) return [] def load_json_tasks(self) -> List[FileTranscriptionTask]: - with open(self.tasks_list_file_path, 'r') as file: + with open(self.tasks_list_file_path, "r") as file: task_ids = json.load(file) tasks = [] @@ -57,7 +61,7 @@ class TasksCache: file.write(json_str) def get_task_path(self, task_id: int): - path = os.path.join(self.cache_dir, 'transcriptions', f'{task_id}.json') + path = os.path.join(self.cache_dir, "transcriptions", f"{task_id}.json") os.makedirs(os.path.dirname(path), exist_ok=True) return path diff --git a/buzz/cli.py b/buzz/cli.py index ad7e0fec..8f5ecff5 100644 --- a/buzz/cli.py +++ b/buzz/cli.py @@ -7,8 +7,14 @@ from PyQt6.QtCore import QCommandLineParser, QCommandLineOption from buzz.gui import Application from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel from buzz.store.keyring_store import KeyringStore -from buzz.transcriber import Task, FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, LANGUAGES, \ - OutputFormat +from buzz.transcriber import ( + Task, + FileTranscriptionTask, + FileTranscriptionOptions, + TranscriptionOptions, + LANGUAGES, + OutputFormat, +) class CommandLineError(Exception): @@ -17,11 +23,11 @@ class CommandLineError(Exception): class CommandLineModelType(enum.Enum): - WHISPER = 'whisper' - WHISPER_CPP = 'whispercpp' - HUGGING_FACE = 'huggingface' - FASTER_WHISPER = 'fasterwhisper' - OPEN_AI_WHISPER_API = 'openaiapi' + WHISPER = "whisper" + WHISPER_CPP = "whispercpp" + HUGGING_FACE = "huggingface" + FASTER_WHISPER = "fasterwhisper" + OPEN_AI_WHISPER_API = "openaiapi" def parse_command_line(app: Application): @@ -29,13 +35,13 @@ def parse_command_line(app: Application): try: parse(app, parser) except CommandLineError as exc: - print(f'Error: {str(exc)}\n', file=sys.stderr) + print(f"Error: {str(exc)}\n", file=sys.stderr) print(parser.helpText()) sys.exit(1) def parse(app: Application, parser: QCommandLineParser): - parser.addPositionalArgument('', 'One of the following commands:\n- add') + parser.addPositionalArgument("", "One of the following commands:\n- add") parser.parse(app.arguments()) args = parser.positionalArguments() @@ -50,36 +56,63 @@ def parse(app: Application, parser: QCommandLineParser): if command == "add": parser.clearPositionalArguments() - parser.addPositionalArgument('files', 'Input file paths', '[file file file...]') + parser.addPositionalArgument("files", "Input file paths", "[file file file...]") - task_option = QCommandLineOption(['t', 'task'], - f'The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.', - 'task', - Task.TRANSCRIBE.value) - model_type_option = QCommandLineOption(['m', 'model-type'], - f'Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.', - 'model-type', - CommandLineModelType.WHISPER.value) - model_size_option = QCommandLineOption(['s', 'model-size'], - f'Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.', - 'model-size', WhisperModelSize.TINY.value) - hugging_face_model_id_option = QCommandLineOption(['hfid'], - f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"', - 'id') - language_option = QCommandLineOption(['l', 'language'], - f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.', - 'code', '') - initial_prompt_option = QCommandLineOption(['p', 'prompt'], f'Initial prompt', 'prompt', '') - open_ai_access_token_option = QCommandLineOption('openai-token', - f'OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.', - 'token') - srt_option = QCommandLineOption(['srt'], 'Output result in an SRT file.') - vtt_option = QCommandLineOption(['vtt'], 'Output result in a VTT file.') - txt_option = QCommandLineOption('txt', 'Output result in a TXT file.') + task_option = QCommandLineOption( + ["t", "task"], + f"The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.", + "task", + Task.TRANSCRIBE.value, + ) + model_type_option = QCommandLineOption( + ["m", "model-type"], + f"Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.", + "model-type", + CommandLineModelType.WHISPER.value, + ) + model_size_option = QCommandLineOption( + ["s", "model-size"], + f"Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.", + "model-size", + WhisperModelSize.TINY.value, + ) + hugging_face_model_id_option = QCommandLineOption( + ["hfid"], + f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"', + "id", + ) + language_option = QCommandLineOption( + ["l", "language"], + f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.', + "code", + "", + ) + initial_prompt_option = QCommandLineOption( + ["p", "prompt"], f"Initial prompt", "prompt", "" + ) + open_ai_access_token_option = QCommandLineOption( + "openai-token", + f"OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.", + "token", + ) + srt_option = QCommandLineOption(["srt"], "Output result in an SRT file.") + vtt_option = QCommandLineOption(["vtt"], "Output result in a VTT file.") + txt_option = QCommandLineOption("txt", "Output result in a TXT file.") parser.addOptions( - [task_option, model_type_option, model_size_option, hugging_face_model_id_option, language_option, - initial_prompt_option, open_ai_access_token_option, srt_option, vtt_option, txt_option]) + [ + task_option, + model_type_option, + model_size_option, + hugging_face_model_id_option, + language_option, + initial_prompt_option, + open_ai_access_token_option, + srt_option, + vtt_option, + txt_option, + ] + ) parser.addHelpOption() parser.addVersionOption() @@ -89,7 +122,7 @@ def parse(app: Application, parser: QCommandLineParser): # slice after first argument, the command file_paths = parser.positionalArguments()[1:] if len(file_paths) == 0: - raise CommandLineError('No input files') + raise CommandLineError("No input files") task = parse_enum_option(task_option, parser, Task) @@ -98,21 +131,29 @@ def parse(app: Application, parser: QCommandLineParser): hugging_face_model_id = parser.value(hugging_face_model_id_option) - if hugging_face_model_id == '' and model_type == CommandLineModelType.HUGGING_FACE: - raise CommandLineError('--hfid is required when --model-type is huggingface') + if ( + hugging_face_model_id == "" + and model_type == CommandLineModelType.HUGGING_FACE + ): + raise CommandLineError( + "--hfid is required when --model-type is huggingface" + ) - model = TranscriptionModel(model_type=ModelType[model_type.name], whisper_model_size=model_size, - hugging_face_model_id=hugging_face_model_id) + model = TranscriptionModel( + model_type=ModelType[model_type.name], + whisper_model_size=model_size, + hugging_face_model_id=hugging_face_model_id, + ) model_path = model.get_local_model_path() if model_path is None: - raise CommandLineError('Model not found') + raise CommandLineError("Model not found") language = parser.value(language_option) - if language == '': + if language == "": language = None elif LANGUAGES.get(language) is None: - raise CommandLineError('Invalid language option') + raise CommandLineError("Invalid language option") initial_prompt = parser.value(initial_prompt_option) @@ -125,33 +166,49 @@ def parse(app: Application, parser: QCommandLineParser): output_formats.add(OutputFormat.TXT) openai_access_token = parser.value(open_ai_access_token_option) - if model.model_type == ModelType.OPEN_AI_WHISPER_API and openai_access_token == '': - openai_access_token = KeyringStore().get_password(key=KeyringStore.Key.OPENAI_API_KEY) + if ( + model.model_type == ModelType.OPEN_AI_WHISPER_API + and openai_access_token == "" + ): + openai_access_token = KeyringStore().get_password( + key=KeyringStore.Key.OPENAI_API_KEY + ) - if openai_access_token == '': - raise CommandLineError('No OpenAI access token found') + if openai_access_token == "": + raise CommandLineError("No OpenAI access token found") - transcription_options = TranscriptionOptions(model=model, task=task, language=language, - initial_prompt=initial_prompt, - openai_access_token=openai_access_token) - file_transcription_options = FileTranscriptionOptions(file_paths=file_paths, output_formats=output_formats) + transcription_options = TranscriptionOptions( + model=model, + task=task, + language=language, + initial_prompt=initial_prompt, + openai_access_token=openai_access_token, + ) + file_transcription_options = FileTranscriptionOptions( + file_paths=file_paths, output_formats=output_formats + ) for file_path in file_paths: - transcription_task = FileTranscriptionTask(file_path=file_path, model_path=model_path, - transcription_options=transcription_options, - file_transcription_options=file_transcription_options) + transcription_task = FileTranscriptionTask( + file_path=file_path, + model_path=model_path, + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + ) app.add_task(transcription_task) T = typing.TypeVar("T", bound=enum.Enum) -def parse_enum_option(option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]) -> T: +def parse_enum_option( + option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T] +) -> T: try: return enum_class(parser.value(option)) except ValueError: - raise CommandLineError(f'Invalid value for --{option.names()[-1]} option.') + raise CommandLineError(f"Invalid value for --{option.names()[-1]} option.") def join_values(enum_class: typing.Type[enum.Enum]) -> str: - return ', '.join([v.value for v in enum_class]) + return ", ".join([v.value for v in enum_class]) diff --git a/buzz/dialogs.py b/buzz/dialogs.py index b1b7ff59..5f9354e9 100644 --- a/buzz/dialogs.py +++ b/buzz/dialogs.py @@ -2,9 +2,10 @@ from PyQt6.QtWidgets import QWidget, QMessageBox def show_model_download_error_dialog(parent: QWidget, error: str): - message = parent.tr( - 'An error occurred while loading the Whisper model') + \ - f": {error}{'' if error.endswith('.') else '.'}" + \ - parent.tr("Please retry or check the application logs for more information.") + message = ( + parent.tr("An error occurred while loading the Whisper model") + + f": {error}{'' if error.endswith('.') else '.'}" + + parent.tr("Please retry or check the application logs for more information.") + ) - QMessageBox.critical(parent, '', message) + QMessageBox.critical(parent, "", message) diff --git a/buzz/file_transcriber_queue_worker.py b/buzz/file_transcriber_queue_worker.py index 82783f06..cfa76abd 100644 --- a/buzz/file_transcriber_queue_worker.py +++ b/buzz/file_transcriber_queue_worker.py @@ -7,8 +7,14 @@ from typing import Optional, Tuple, List from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot from buzz.model_loader import ModelType -from buzz.transcriber import FileTranscriptionTask, FileTranscriber, WhisperCppFileTranscriber, \ - OpenAIWhisperAPIFileTranscriber, WhisperFileTranscriber, Segment +from buzz.transcriber import ( + FileTranscriptionTask, + FileTranscriber, + WhisperCppFileTranscriber, + OpenAIWhisperAPIFileTranscriber, + WhisperFileTranscriber, + Segment, +) class FileTranscriberQueueWorker(QObject): @@ -26,7 +32,7 @@ class FileTranscriberQueueWorker(QObject): @pyqtSlot() def run(self): - logging.debug('Waiting for next transcription task') + logging.debug("Waiting for next transcription task") # Get next non-canceled task from queue while True: @@ -42,38 +48,37 @@ class FileTranscriberQueueWorker(QObject): break - logging.debug('Starting next transcription task') + logging.debug("Starting next transcription task") model_type = self.current_task.transcription_options.model.model_type if model_type == ModelType.WHISPER_CPP: - self.current_transcriber = WhisperCppFileTranscriber( - task=self.current_task) + self.current_transcriber = WhisperCppFileTranscriber(task=self.current_task) elif model_type == ModelType.OPEN_AI_WHISPER_API: - self.current_transcriber = OpenAIWhisperAPIFileTranscriber(task=self.current_task) - elif model_type == ModelType.HUGGING_FACE or \ - model_type == ModelType.WHISPER or \ - model_type == ModelType.FASTER_WHISPER: + self.current_transcriber = OpenAIWhisperAPIFileTranscriber( + task=self.current_task + ) + elif ( + model_type == ModelType.HUGGING_FACE + or model_type == ModelType.WHISPER + or model_type == ModelType.FASTER_WHISPER + ): self.current_transcriber = WhisperFileTranscriber(task=self.current_task) else: - raise Exception(f'Unknown model type: {model_type}') + raise Exception(f"Unknown model type: {model_type}") self.current_transcriber_thread = QThread(self) self.current_transcriber.moveToThread(self.current_transcriber_thread) - self.current_transcriber_thread.started.connect( - self.current_transcriber.run) - self.current_transcriber.completed.connect( - self.current_transcriber_thread.quit) - self.current_transcriber.error.connect( - self.current_transcriber_thread.quit) + self.current_transcriber_thread.started.connect(self.current_transcriber.run) + self.current_transcriber.completed.connect(self.current_transcriber_thread.quit) + self.current_transcriber.error.connect(self.current_transcriber_thread.quit) - self.current_transcriber.completed.connect( - self.current_transcriber.deleteLater) - self.current_transcriber.error.connect( - self.current_transcriber.deleteLater) + self.current_transcriber.completed.connect(self.current_transcriber.deleteLater) + self.current_transcriber.error.connect(self.current_transcriber.deleteLater) self.current_transcriber_thread.finished.connect( - self.current_transcriber_thread.deleteLater) + self.current_transcriber_thread.deleteLater + ) self.current_transcriber.progress.connect(self.on_task_progress) self.current_transcriber.error.connect(self.on_task_error) @@ -104,7 +109,10 @@ class FileTranscriberQueueWorker(QObject): @pyqtSlot(Exception) def on_task_error(self, error: Exception): - if self.current_task is not None and self.current_task.id not in self.canceled_tasks: + if ( + self.current_task is not None + and self.current_task.id not in self.canceled_tasks + ): self.current_task.status = FileTranscriptionTask.Status.FAILED self.current_task.error = str(error) self.task_updated.emit(self.current_task) diff --git a/buzz/gui.py b/buzz/gui.py index 496d024a..c122d810 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -5,13 +5,23 @@ from typing import Dict, List, Optional, Tuple import sounddevice from PyQt6 import QtGui -from PyQt6.QtCore import (Qt, QThread, - pyqtSignal, QModelIndex, QThreadPool) -from PyQt6.QtGui import (QCloseEvent, QIcon, - QKeySequence, QTextCursor, QPainter, QColor) -from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit, - QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QFormLayout, - QSizePolicy) +from PyQt6.QtCore import Qt, QThread, pyqtSignal, QModelIndex, QThreadPool +from PyQt6.QtGui import QCloseEvent, QIcon, QKeySequence, QTextCursor, QPainter, QColor +from PyQt6.QtWidgets import ( + QApplication, + QComboBox, + QFileDialog, + QLabel, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QVBoxLayout, + QHBoxLayout, + QWidget, + QFormLayout, + QSizePolicy, +) from buzz.cache import TasksCache from .__version__ import VERSION @@ -20,25 +30,35 @@ from .assets import get_asset_path from .dialogs import show_model_download_error_dialog from .widgets.icon import Icon, BUZZ_ICON_PATH from .locale import _ -from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \ - ModelDownloader +from .model_loader import ( + WhisperModelSize, + ModelType, + TranscriptionModel, + ModelDownloader, +) from .recording import RecordingAmplitudeListener from .settings.settings import Settings, APP_NAME from .settings.shortcut import Shortcut from .settings.shortcut_settings import ShortcutSettings from .store.keyring_store import KeyringStore -from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Task, - TranscriptionOptions, - FileTranscriptionTask, LOADED_WHISPER_DLL, - DEFAULT_WHISPER_TEMPERATURE) +from .transcriber import ( + SUPPORTED_OUTPUT_FORMATS, + FileTranscriptionOptions, + Task, + TranscriptionOptions, + FileTranscriptionTask, + LOADED_WHISPER_DLL, + DEFAULT_WHISPER_TEMPERATURE, +) from .recording_transcriber import RecordingTranscriber from .file_transcriber_queue_worker import FileTranscriberQueueWorker from .widgets.menu_bar import MenuBar from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog from .widgets.toolbar import ToolBar from .widgets.transcriber.file_transcriber_widget import FileTranscriberWidget -from .widgets.transcriber.transcription_options_group_box import \ - TranscriptionOptionsGroupBox +from .widgets.transcriber.transcription_options_group_box import ( + TranscriptionOptionsGroupBox, +) from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget from .widgets.transcription_viewer_widget import TranscriptionViewerWidget @@ -46,13 +66,17 @@ from .widgets.transcription_viewer_widget import TranscriptionViewerWidget class FormLabel(QLabel): def __init__(self, name: str, parent: Optional[QWidget], *args) -> None: super().__init__(name, parent, *args) - self.setStyleSheet('QLabel { text-align: right; }') - self.setAlignment(Qt.AlignmentFlag( - Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight)) + self.setStyleSheet("QLabel { text-align: right; }") + self.setAlignment( + Qt.AlignmentFlag( + Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight + ) + ) class AudioDevicesComboBox(QComboBox): """AudioDevicesComboBox displays a list of available audio input devices""" + device_changed = pyqtSignal(int) audio_devices: List[Tuple[int, str]] @@ -71,13 +95,18 @@ class AudioDevicesComboBox(QComboBox): def get_audio_devices(self) -> List[Tuple[int, str]]: try: devices: sounddevice.DeviceList = sounddevice.query_devices() - return [(device.get('index'), device.get('name')) - for device in devices if device.get('max_input_channels') > 0] + return [ + (device.get("index"), device.get("name")) + for device in devices + if device.get("max_input_channels") > 0 + ] except UnicodeDecodeError: QMessageBox.critical( - self, '', - 'An error occurred while loading your audio devices. Please check the application logs for more ' - 'information.') + self, + "", + "An error occurred while loading your audio devices. Please check the application logs for more " + "information.", + ) return [] def on_index_changed(self, index: int): @@ -107,14 +136,16 @@ class RecordButton(QPushButton): def __init__(self, parent: Optional[QWidget]) -> None: super().__init__(_("Record"), parent) self.setDefault(True) - self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)) + self.setSizePolicy( + QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + ) def set_stopped(self): - self.setText(_('Record')) + self.setText(_("Record")) self.setDefault(True) def set_recording(self): - self.setText(_('Stop')) + self.setText(_("Stop")) self.setDefault(False) @@ -143,11 +174,11 @@ class AudioMeterWidget(QWidget): self.AMPLITUDE_SCALE_FACTOR = 15 # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars if self.palette().window().color().black() > 127: - self.BAR_INACTIVE_COLOR = QColor('#555') - self.BAR_ACTIVE_COLOR = QColor('#999') + self.BAR_INACTIVE_COLOR = QColor("#555") + self.BAR_ACTIVE_COLOR = QColor("#999") else: - self.BAR_INACTIVE_COLOR = QColor('#BBB') - self.BAR_ACTIVE_COLOR = QColor('#555') + self.BAR_INACTIVE_COLOR = QColor("#BBB") + self.BAR_ACTIVE_COLOR = QColor("#555") def paintEvent(self, event: QtGui.QPaintEvent) -> None: painter = QPainter(self) @@ -157,26 +188,38 @@ class AudioMeterWidget(QWidget): center_x = rect.center().x() num_bars_in_half = int((rect.width() / 2) / (self.BAR_MARGIN + self.BAR_WIDTH)) for i in range(num_bars_in_half): - is_bar_active = ((self.current_amplitude - self.MINIMUM_AMPLITUDE) * self.AMPLITUDE_SCALE_FACTOR) > ( - i / num_bars_in_half) - painter.setBrush(self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR) + is_bar_active = ( + (self.current_amplitude - self.MINIMUM_AMPLITUDE) + * self.AMPLITUDE_SCALE_FACTOR + ) > (i / num_bars_in_half) + painter.setBrush( + self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR + ) # draw to left - painter.drawRect(center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)), rect.top() + self.PADDING_TOP, - self.BAR_WIDTH, - rect.height() - self.PADDING_TOP) + painter.drawRect( + center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)), + rect.top() + self.PADDING_TOP, + self.BAR_WIDTH, + rect.height() - self.PADDING_TOP, + ) # draw to right - painter.drawRect(center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))), - rect.top() + self.PADDING_TOP, - self.BAR_WIDTH, rect.height() - self.PADDING_TOP) + painter.drawRect( + center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))), + rect.top() + self.PADDING_TOP, + self.BAR_WIDTH, + rect.height() - self.PADDING_TOP, + ) def update_amplitude(self, amplitude: float): - self.current_amplitude = max(amplitude, self.current_amplitude * self.SMOOTHING_FACTOR) + self.current_amplitude = max( + amplitude, self.current_amplitude * self.SMOOTHING_FACTOR + ) self.repaint() class RecordingTranscriberWidget(QWidget): - current_status: 'RecordingStatus' + current_status: "RecordingStatus" transcription_options: TranscriptionOptions selected_device_id: Optional[int] model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None @@ -190,7 +233,9 @@ class RecordingTranscriberWidget(QWidget): STOPPED = auto() RECORDING = auto() - def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None) -> None: + def __init__( + self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None + ) -> None: super().__init__(parent) if flags is not None: @@ -199,42 +244,63 @@ class RecordingTranscriberWidget(QWidget): layout = QVBoxLayout(self) self.current_status = self.RecordingStatus.STOPPED - self.setWindowTitle(_('Live Recording')) + self.setWindowTitle(_("Live Recording")) self.settings = Settings() - default_language = self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value='') + default_language = self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value="" + ) self.transcription_options = TranscriptionOptions( - model=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_MODEL, default_value=TranscriptionModel( - model_type=ModelType.WHISPER_CPP if LOADED_WHISPER_DLL else ModelType.WHISPER, - whisper_model_size=WhisperModelSize.TINY)), - task=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE), - language=default_language if default_language != '' else None, - initial_prompt=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value=''), - temperature=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE, - default_value=DEFAULT_WHISPER_TEMPERATURE), word_level_timings=False) + model=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_MODEL, + default_value=TranscriptionModel( + model_type=ModelType.WHISPER_CPP + if LOADED_WHISPER_DLL + else ModelType.WHISPER, + whisper_model_size=WhisperModelSize.TINY, + ), + ), + task=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_TASK, + default_value=Task.TRANSCRIBE, + ), + language=default_language if default_language != "" else None, + initial_prompt=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value="" + ), + temperature=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE, + default_value=DEFAULT_WHISPER_TEMPERATURE, + ), + word_level_timings=False, + ) self.audio_devices_combo_box = AudioDevicesComboBox(self) - self.audio_devices_combo_box.device_changed.connect( - self.on_device_changed) + self.audio_devices_combo_box.device_changed.connect(self.on_device_changed) self.selected_device_id = self.audio_devices_combo_box.get_default_device_id() self.record_button = RecordButton(self) self.record_button.clicked.connect(self.on_record_button_clicked) self.text_box = TextDisplayBox(self) - self.text_box.setPlaceholderText(_('Click Record to begin...')) + self.text_box.setPlaceholderText(_("Click Record to begin...")) transcription_options_group_box = TranscriptionOptionsGroupBox( default_transcription_options=self.transcription_options, # Live transcription with OpenAI Whisper API not implemented - model_types=[model_type for model_type in ModelType if model_type is not ModelType.OPEN_AI_WHISPER_API], - parent=self) + model_types=[ + model_type + for model_type in ModelType + if model_type is not ModelType.OPEN_AI_WHISPER_API + ], + parent=self, + ) transcription_options_group_box.transcription_options_changed.connect( - self.on_transcription_options_changed) + self.on_transcription_options_changed + ) recording_options_layout = QFormLayout() - recording_options_layout.addRow( - _('Microphone:'), self.audio_devices_combo_box) + recording_options_layout.addRow(_("Microphone:"), self.audio_devices_combo_box) self.audio_meter_widget = AudioMeterWidget(self) @@ -252,7 +318,9 @@ class RecordingTranscriberWidget(QWidget): self.reset_recording_amplitude_listener() - def on_transcription_options_changed(self, transcription_options: TranscriptionOptions): + def on_transcription_options_changed( + self, transcription_options: TranscriptionOptions + ): self.transcription_options = transcription_options def on_device_changed(self, device_id: int): @@ -269,11 +337,16 @@ class RecordingTranscriberWidget(QWidget): # Get the device sample rate before starting the listener as the PortAudio function # fails if you try to get the device's settings while recording is in progress. - self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(self.selected_device_id) + self.device_sample_rate = RecordingTranscriber.get_device_sample_rate( + self.selected_device_id + ) - self.recording_amplitude_listener = RecordingAmplitudeListener(input_device_index=self.selected_device_id, - parent=self) - self.recording_amplitude_listener.amplitude_changed.connect(self.on_recording_amplitude_changed) + self.recording_amplitude_listener = RecordingAmplitudeListener( + input_device_index=self.selected_device_id, parent=self + ) + self.recording_amplitude_listener.amplitude_changed.connect( + self.on_recording_amplitude_changed + ) self.recording_amplitude_listener.start_recording() def on_record_button_clicked(self): @@ -306,16 +379,19 @@ class RecordingTranscriberWidget(QWidget): self.transcription_thread = QThread() # TODO: make runnable - self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id, - sample_rate=self.device_sample_rate, - transcription_options=self.transcription_options, - model_path=model_path) + self.transcriber = RecordingTranscriber( + input_device_index=self.selected_device_id, + sample_rate=self.device_sample_rate, + transcription_options=self.transcription_options, + model_path=model_path, + ) self.transcriber.moveToThread(self.transcription_thread) self.transcription_thread.started.connect(self.transcriber.start) self.transcription_thread.finished.connect( - self.transcription_thread.deleteLater) + self.transcription_thread.deleteLater + ) self.transcriber.transcription.connect(self.on_next_transcription) @@ -334,12 +410,16 @@ class RecordingTranscriberWidget(QWidget): if self.model_download_progress_dialog is None: self.model_download_progress_dialog = ModelDownloadProgressDialog( - model_type=self.transcription_options.model.model_type, parent=self) + model_type=self.transcription_options.model.model_type, parent=self + ) self.model_download_progress_dialog.canceled.connect( - self.on_cancel_model_progress_dialog) + self.on_cancel_model_progress_dialog + ) if self.model_download_progress_dialog is not None: - self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size) + self.model_download_progress_dialog.set_value( + fraction_completed=current_size / total_size + ) def set_recording_status_stopped(self): self.record_button.set_stopped() @@ -357,7 +437,7 @@ class RecordingTranscriberWidget(QWidget): if len(text) > 0: self.text_box.moveCursor(QTextCursor.MoveOperation.End) if len(self.text_box.toPlainText()) > 0: - self.text_box.insertPlainText('\n\n') + self.text_box.insertPlainText("\n\n") self.text_box.insertPlainText(text) self.text_box.moveCursor(QTextCursor.MoveOperation.End) @@ -374,9 +454,15 @@ class RecordingTranscriberWidget(QWidget): self.reset_record_button() self.set_recording_status_stopped() QMessageBox.critical( - self, '', - _('An error occurred while starting a new recording:') + error + '. ' + - _('Please check your audio devices or check the application logs for more information.')) + self, + "", + _("An error occurred while starting a new recording:") + + error + + ". " + + _( + "Please check your audio devices or check the application logs for more information." + ), + ) def on_cancel_model_progress_dialog(self): if self.model_loader is not None: @@ -392,7 +478,7 @@ class RecordingTranscriberWidget(QWidget): def reset_recording_controls(self): # Clear text box placeholder because the first chunk takes a while to process - self.text_box.setPlaceholderText('') + self.text_box.setPlaceholderText("") self.reset_record_button() self.reset_model_download() @@ -411,49 +497,72 @@ class RecordingTranscriberWidget(QWidget): self.recording_amplitude_listener.stop_recording() self.recording_amplitude_listener.deleteLater() - self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, self.transcription_options.language) - self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_TASK, self.transcription_options.task) - self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE, self.transcription_options.temperature) - self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, - self.transcription_options.initial_prompt) - self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, + self.transcription_options.language, + ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_TASK, self.transcription_options.task + ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE, + self.transcription_options.temperature, + ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, + self.transcription_options.initial_prompt, + ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model + ) return super().closeEvent(event) -RECORD_ICON_PATH = get_asset_path('assets/mic_FILL0_wght700_GRAD0_opsz48.svg') -EXPAND_ICON_PATH = get_asset_path('assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg') -ADD_ICON_PATH = get_asset_path('assets/add_FILL0_wght700_GRAD0_opsz48.svg') -TRASH_ICON_PATH = get_asset_path('assets/delete_FILL0_wght700_GRAD0_opsz48.svg') -CANCEL_ICON_PATH = get_asset_path('assets/cancel_FILL0_wght700_GRAD0_opsz48.svg') +RECORD_ICON_PATH = get_asset_path("assets/mic_FILL0_wght700_GRAD0_opsz48.svg") +EXPAND_ICON_PATH = get_asset_path("assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg") +ADD_ICON_PATH = get_asset_path("assets/add_FILL0_wght700_GRAD0_opsz48.svg") +TRASH_ICON_PATH = get_asset_path("assets/delete_FILL0_wght700_GRAD0_opsz48.svg") +CANCEL_ICON_PATH = get_asset_path("assets/cancel_FILL0_wght700_GRAD0_opsz48.svg") class MainWindowToolbar(ToolBar): new_transcription_action_triggered: pyqtSignal open_transcript_action_triggered: pyqtSignal clear_history_action_triggered: pyqtSignal - ICON_LIGHT_THEME_BACKGROUND = '#555' - ICON_DARK_THEME_BACKGROUND = '#AAA' + ICON_LIGHT_THEME_BACKGROUND = "#555" + ICON_DARK_THEME_BACKGROUND = "#AAA" def __init__(self, shortcuts: Dict[str, str], parent: Optional[QWidget]): super().__init__(parent) - self.record_action = Action(Icon(RECORD_ICON_PATH, self), _('Record'), self) + self.record_action = Action(Icon(RECORD_ICON_PATH, self), _("Record"), self) self.record_action.triggered.connect(self.on_record_action_triggered) - self.new_transcription_action = Action(Icon(ADD_ICON_PATH, self), _('New Transcription'), self) - self.new_transcription_action_triggered = self.new_transcription_action.triggered + self.new_transcription_action = Action( + Icon(ADD_ICON_PATH, self), _("New Transcription"), self + ) + self.new_transcription_action_triggered = ( + self.new_transcription_action.triggered + ) - self.open_transcript_action = Action(Icon(EXPAND_ICON_PATH, self), - _('Open Transcript'), self) + self.open_transcript_action = Action( + Icon(EXPAND_ICON_PATH, self), _("Open Transcript"), self + ) self.open_transcript_action_triggered = self.open_transcript_action.triggered self.open_transcript_action.setDisabled(True) - self.stop_transcription_action = Action(Icon(CANCEL_ICON_PATH, self), _('Cancel Transcription'), self) - self.stop_transcription_action_triggered = self.stop_transcription_action.triggered + self.stop_transcription_action = Action( + Icon(CANCEL_ICON_PATH, self), _("Cancel Transcription"), self + ) + self.stop_transcription_action_triggered = ( + self.stop_transcription_action.triggered + ) self.stop_transcription_action.setDisabled(True) - self.clear_history_action = Action(Icon(TRASH_ICON_PATH, self), _('Clear History'), self) + self.clear_history_action = Action( + Icon(TRASH_ICON_PATH, self), _("Clear History"), self + ) self.clear_history_action_triggered = self.clear_history_action.triggered self.clear_history_action.setDisabled(True) @@ -461,21 +570,38 @@ class MainWindowToolbar(ToolBar): self.addAction(self.record_action) self.addSeparator() - self.addActions([self.new_transcription_action, self.open_transcript_action, self.stop_transcription_action, - self.clear_history_action]) + self.addActions( + [ + self.new_transcription_action, + self.open_transcript_action, + self.stop_transcription_action, + self.clear_history_action, + ] + ) self.setMovable(False) self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) def set_shortcuts(self, shortcuts: Dict[str, str]): - self.record_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_RECORD_WINDOW.name])) - self.new_transcription_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name])) + self.record_action.setShortcut( + QKeySequence.fromString(shortcuts[Shortcut.OPEN_RECORD_WINDOW.name]) + ) + self.new_transcription_action.setShortcut( + QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]) + ) self.open_transcript_action.setShortcut( - QKeySequence.fromString(shortcuts[Shortcut.OPEN_TRANSCRIPT_EDITOR.name])) - self.stop_transcription_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.STOP_TRANSCRIPTION.name])) - self.clear_history_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.CLEAR_HISTORY.name])) + QKeySequence.fromString(shortcuts[Shortcut.OPEN_TRANSCRIPT_EDITOR.name]) + ) + self.stop_transcription_action.setShortcut( + QKeySequence.fromString(shortcuts[Shortcut.STOP_TRANSCRIPTION.name]) + ) + self.clear_history_action.setShortcut( + QKeySequence.fromString(shortcuts[Shortcut.CLEAR_HISTORY.name]) + ) def on_record_action_triggered(self): - recording_transcriber_window = RecordingTranscriberWidget(self, flags=Qt.WindowType.Window) + recording_transcriber_window = RecordingTranscriberWidget( + self, flags=Qt.WindowType.Window + ) recording_transcriber_window.show() def set_stop_transcription_action_enabled(self, enabled: bool): @@ -490,7 +616,7 @@ class MainWindowToolbar(ToolBar): class MainWindow(QMainWindow): table_widget: TranscriptionTasksTableWidget - tasks: Dict[int, 'FileTranscriptionTask'] + tasks: Dict[int, "FileTranscriptionTask"] tasks_changed = pyqtSignal() openai_access_token: Optional[str] @@ -511,34 +637,49 @@ class MainWindow(QMainWindow): self.shortcuts = self.shortcut_settings.load() self.default_export_file_name = self.settings.value( Settings.Key.DEFAULT_EXPORT_FILE_NAME, - '{{ input_file_name }} ({{ task }}d on {{ date_time }})') + "{{ input_file_name }} ({{ task }}d on {{ date_time }})", + ) self.tasks = {} self.tasks_changed.connect(self.on_tasks_changed) self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self) - self.toolbar.new_transcription_action_triggered.connect(self.on_new_transcription_action_triggered) - self.toolbar.open_transcript_action_triggered.connect(self.open_transcript_viewer) - self.toolbar.clear_history_action_triggered.connect(self.on_clear_history_action_triggered) - self.toolbar.stop_transcription_action_triggered.connect(self.on_stop_transcription_action_triggered) + self.toolbar.new_transcription_action_triggered.connect( + self.on_new_transcription_action_triggered + ) + self.toolbar.open_transcript_action_triggered.connect( + self.open_transcript_viewer + ) + self.toolbar.clear_history_action_triggered.connect( + self.on_clear_history_action_triggered + ) + self.toolbar.stop_transcription_action_triggered.connect( + self.on_stop_transcription_action_triggered + ) self.addToolBar(self.toolbar) self.setUnifiedTitleAndToolBarOnMac(True) - self.menu_bar = MenuBar(shortcuts=self.shortcuts, - default_export_file_name=self.default_export_file_name, - parent=self) + self.menu_bar = MenuBar( + shortcuts=self.shortcuts, + default_export_file_name=self.default_export_file_name, + parent=self, + ) self.menu_bar.import_action_triggered.connect( - self.on_new_transcription_action_triggered) + self.on_new_transcription_action_triggered + ) self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed) - self.menu_bar.openai_api_key_changed.connect(self.on_openai_access_token_changed) - self.menu_bar.default_export_file_name_changed.connect(self.default_export_file_name_changed) + self.menu_bar.openai_api_key_changed.connect( + self.on_openai_access_token_changed + ) + self.menu_bar.default_export_file_name_changed.connect( + self.default_export_file_name_changed + ) self.setMenuBar(self.menu_bar) self.table_widget = TranscriptionTasksTableWidget(self) self.table_widget.doubleClicked.connect(self.on_table_double_clicked) self.table_widget.return_clicked.connect(self.open_transcript_viewer) - self.table_widget.itemSelectionChanged.connect( - self.on_table_selection_changed) + self.table_widget.itemSelectionChanged.connect(self.on_table_selection_changed) self.setCentralWidget(self.table_widget) @@ -548,8 +689,7 @@ class MainWindow(QMainWindow): self.transcriber_worker = FileTranscriberQueueWorker() self.transcriber_worker.moveToThread(self.transcriber_thread) - self.transcriber_worker.task_updated.connect( - self.update_task_table_row) + self.transcriber_worker.task_updated.connect(self.update_task_table_row) self.transcriber_worker.completed.connect(self.transcriber_thread.quit) self.transcriber_thread.started.connect(self.transcriber_worker.run) @@ -569,11 +709,14 @@ class MainWindow(QMainWindow): file_paths = [url.toLocalFile() for url in event.mimeData().urls()] self.open_file_transcriber_widget(file_paths=file_paths) - def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]): + def on_file_transcriber_triggered( + self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str] + ): transcription_options, file_transcription_options, model_path = options for file_path in file_transcription_options.file_paths: task = FileTranscriptionTask( - file_path, transcription_options, file_transcription_options, model_path) + file_path, transcription_options, file_transcription_options, model_path + ) self.add_task(task) def load_task(self, task: FileTranscriptionTask): @@ -586,8 +729,10 @@ class MainWindow(QMainWindow): @staticmethod def task_completed_or_errored(task: FileTranscriptionTask): - return task.status == FileTranscriptionTask.Status.COMPLETED or \ - task.status == FileTranscriptionTask.Status.FAILED + return ( + task.status == FileTranscriptionTask.Status.COMPLETED + or task.status == FileTranscriptionTask.Status.FAILED + ) def on_clear_history_action_triggered(self): selected_rows = self.table_widget.selectionModel().selectedRows() @@ -595,10 +740,17 @@ class MainWindow(QMainWindow): return reply = QMessageBox.question( - self, _('Clear History'), - _('Are you sure you want to delete the selected transcription(s)? This action cannot be undone.')) + self, + _("Clear History"), + _( + "Are you sure you want to delete the selected transcription(s)? This action cannot be undone." + ), + ) if reply == QMessageBox.StandardButton.Yes: - task_ids = [TranscriptionTasksTableWidget.find_task_id(selected_row) for selected_row in selected_rows] + task_ids = [ + TranscriptionTasksTableWidget.find_task_id(selected_row) + for selected_row in selected_rows + ] for task_id in task_ids: self.table_widget.clear_task(task_id) self.tasks.pop(task_id) @@ -617,20 +769,24 @@ class MainWindow(QMainWindow): def on_new_transcription_action_triggered(self): (file_paths, __) = QFileDialog.getOpenFileNames( - self, _('Select audio file'), '', SUPPORTED_OUTPUT_FORMATS) + self, _("Select audio file"), "", SUPPORTED_OUTPUT_FORMATS + ) if len(file_paths) == 0: return self.open_file_transcriber_widget(file_paths) def open_file_transcriber_widget(self, file_paths: List[str]): - file_transcriber_window = FileTranscriberWidget(file_paths=file_paths, - default_output_file_name=self.default_export_file_name, - parent=self, - flags=Qt.WindowType.Window) - file_transcriber_window.triggered.connect( - self.on_file_transcriber_triggered) - file_transcriber_window.openai_access_token_changed.connect(self.on_openai_access_token_changed) + file_transcriber_window = FileTranscriberWidget( + file_paths=file_paths, + default_output_file_name=self.default_export_file_name, + parent=self, + flags=Qt.WindowType.Window, + ) + file_transcriber_window.triggered.connect(self.on_file_transcriber_triggered) + file_transcriber_window.openai_access_token_changed.connect( + self.on_openai_access_token_changed + ) file_transcriber_window.show() @staticmethod @@ -639,7 +795,9 @@ class MainWindow(QMainWindow): def default_export_file_name_changed(self, default_export_file_name: str): self.default_export_file_name = default_export_file_name - self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name) + self.settings.set_value( + Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name + ) def open_transcript_viewer(self): selected_rows = self.table_widget.selectionModel().selectedRows() @@ -648,29 +806,49 @@ class MainWindow(QMainWindow): self.open_transcription_viewer(task_id) def on_table_selection_changed(self): - self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action()) - self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action()) - self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action()) + self.toolbar.set_open_transcript_action_enabled( + self.should_enable_open_transcript_action() + ) + self.toolbar.set_stop_transcription_action_enabled( + self.should_enable_stop_transcription_action() + ) + self.toolbar.set_clear_history_action_enabled( + self.should_enable_clear_history_action() + ) def should_enable_open_transcript_action(self): return self.selected_tasks_have_status([FileTranscriptionTask.Status.COMPLETED]) def should_enable_stop_transcription_action(self): return self.selected_tasks_have_status( - [FileTranscriptionTask.Status.IN_PROGRESS, FileTranscriptionTask.Status.QUEUED]) + [ + FileTranscriptionTask.Status.IN_PROGRESS, + FileTranscriptionTask.Status.QUEUED, + ] + ) def should_enable_clear_history_action(self): return self.selected_tasks_have_status( - [FileTranscriptionTask.Status.COMPLETED, FileTranscriptionTask.Status.FAILED, - FileTranscriptionTask.Status.CANCELED]) + [ + FileTranscriptionTask.Status.COMPLETED, + FileTranscriptionTask.Status.FAILED, + FileTranscriptionTask.Status.CANCELED, + ] + ) def selected_tasks_have_status(self, statuses: List[FileTranscriptionTask.Status]): selected_rows = self.table_widget.selectionModel().selectedRows() if len(selected_rows) == 0: return False return all( - [self.tasks[TranscriptionTasksTableWidget.find_task_id(selected_row)].status in statuses for selected_row in - selected_rows]) + [ + self.tasks[ + TranscriptionTasksTableWidget.find_task_id(selected_row) + ].status + in statuses + for selected_row in selected_rows + ] + ) def on_table_double_clicked(self, index: QModelIndex): task_id = TranscriptionTasksTableWidget.find_task_id(index) @@ -682,7 +860,8 @@ class MainWindow(QMainWindow): return transcription_viewer_widget = TranscriptionViewerWidget( - transcription_task=task, parent=self, flags=Qt.WindowType.Window) + transcription_task=task, parent=self, flags=Qt.WindowType.Window + ) transcription_viewer_widget.task_changed.connect(self.on_tasks_changed) transcription_viewer_widget.show() @@ -692,8 +871,10 @@ class MainWindow(QMainWindow): def load_tasks_from_cache(self): tasks = self.tasks_cache.load() for task in tasks: - if task.status == FileTranscriptionTask.Status.QUEUED or \ - task.status == FileTranscriptionTask.Status.IN_PROGRESS: + if ( + task.status == FileTranscriptionTask.Status.QUEUED + or task.status == FileTranscriptionTask.Status.IN_PROGRESS + ): task.status = None self.transcriber_worker.add_task(task) else: @@ -703,9 +884,15 @@ class MainWindow(QMainWindow): self.tasks_cache.save(list(self.tasks.values())) def on_tasks_changed(self): - self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action()) - self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action()) - self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action()) + self.toolbar.set_open_transcript_action_enabled( + self.should_enable_open_transcript_action() + ) + self.toolbar.set_stop_transcription_action_enabled( + self.should_enable_stop_transcription_action() + ) + self.toolbar.set_clear_history_action_enabled( + self.should_enable_clear_history_action() + ) self.save_tasks_to_cache() def on_shortcuts_changed(self, shortcuts: dict): @@ -723,7 +910,6 @@ class MainWindow(QMainWindow): super().closeEvent(event) - class Application(QApplication): window: MainWindow diff --git a/buzz/locale.py b/buzz/locale.py index e47f83f8..0a844725 100644 --- a/buzz/locale.py +++ b/buzz/locale.py @@ -6,12 +6,12 @@ from PyQt6.QtCore import QLocale from buzz.assets import get_asset_path from buzz.settings.settings import APP_NAME -if 'LANG' not in os.environ: +if "LANG" not in os.environ: language = str(QLocale().uiLanguages()[0]).replace("-", "_") - os.environ['LANG'] = language + os.environ["LANG"] = language -locale_dir = get_asset_path('locale') -gettext.bindtextdomain('buzz', locale_dir) +locale_dir = get_asset_path("locale") +gettext.bindtextdomain("buzz", locale_dir) translate = gettext.translation(APP_NAME, locale_dir, fallback=True) diff --git a/buzz/model_loader.py b/buzz/model_loader.py index d4597165..6f7b4722 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -20,11 +20,11 @@ from tqdm.auto import tqdm class WhisperModelSize(str, enum.Enum): - TINY = 'tiny' - BASE = 'base' - SMALL = 'small' - MEDIUM = 'medium' - LARGE = 'large' + TINY = "tiny" + BASE = "base" + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" def to_faster_whisper_model_size(self) -> str: if self == WhisperModelSize.LARGE: @@ -33,11 +33,11 @@ class WhisperModelSize(str, enum.Enum): class ModelType(enum.Enum): - WHISPER = 'Whisper' - WHISPER_CPP = 'Whisper.cpp' - HUGGING_FACE = 'Hugging Face' - FASTER_WHISPER = 'Faster Whisper' - OPEN_AI_WHISPER_API = 'OpenAI Whisper API' + WHISPER = "Whisper" + WHISPER_CPP = "Whisper.cpp" + HUGGING_FACE = "Hugging Face" + FASTER_WHISPER = "Faster Whisper" + OPEN_AI_WHISPER_API = "OpenAI Whisper API" @dataclass() @@ -47,9 +47,10 @@ class TranscriptionModel: hugging_face_model_id: Optional[str] = None def is_deletable(self): - return ((self.model_type == ModelType.WHISPER or - self.model_type == ModelType.WHISPER_CPP) and - self.get_local_model_path() is not None) + return ( + self.model_type == ModelType.WHISPER + or self.model_type == ModelType.WHISPER_CPP + ) and self.get_local_model_path() is not None def open_file_location(self): model_path = self.get_local_model_path() @@ -84,18 +85,20 @@ class TranscriptionModel: if self.model_type == ModelType.FASTER_WHISPER: try: - return download_faster_whisper_model(size=self.whisper_model_size.value, - local_files_only=True) + return download_faster_whisper_model( + size=self.whisper_model_size.value, local_files_only=True + ) except (ValueError, FileNotFoundError): return None if self.model_type == ModelType.OPEN_AI_WHISPER_API: - return '' + return "" if self.model_type == ModelType.HUGGING_FACE: try: - return huggingface_hub.snapshot_download(self.hugging_face_model_id, - local_files_only=True) + return huggingface_hub.snapshot_download( + self.hugging_face_model_id, local_files_only=True + ) except (ValueError, FileNotFoundError): return None @@ -103,36 +106,38 @@ class TranscriptionModel: WHISPER_CPP_MODELS_SHA256 = { - 'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21', - 'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe', - 'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b', - 'medium': '6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208', - 'large': '9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487' + "tiny": "be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21", + "base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe", + "small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b", + "medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208", + "large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487", } def get_hugging_face_file_url(author: str, repository_name: str, filename: str): - return f'https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}' + return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}" def get_whisper_cpp_file_path(size: WhisperModelSize) -> str: - root_dir = user_cache_dir('Buzz') - return os.path.join(root_dir, f'ggml-model-whisper-{size.value}.bin') + root_dir = user_cache_dir("Buzz") + return os.path.join(root_dir, f"ggml-model-whisper-{size.value}.bin") def get_whisper_file_path(size: WhisperModelSize) -> str: - root_dir = os.getenv("XDG_CACHE_HOME", os.path.join( - os.path.expanduser("~"), ".cache", "whisper")) + root_dir = os.getenv( + "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper") + ) url = whisper._MODELS[size.value] return os.path.join(root_dir, os.path.basename(url)) -def download_faster_whisper_model(size: str, local_files_only=False, - tqdm_class: Optional[tqdm] = None): +def download_faster_whisper_model( + size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None +): if size not in faster_whisper.utils._MODELS: raise ValueError( - "Invalid model size '%s', expected one of: %s" % ( - size, ", ".join(faster_whisper.utils._MODELS)) + "Invalid model size '%s', expected one of: %s" + % (size, ", ".join(faster_whisper.utils._MODELS)) ) repo_id = "guillaumekln/faster-whisper-%s" % size @@ -144,9 +149,12 @@ def download_faster_whisper_model(size: str, local_files_only=False, "vocabulary.txt", ] - return huggingface_hub.snapshot_download(repo_id, allow_patterns=allow_patterns, - local_files_only=local_files_only, - tqdm_class=tqdm_class) + return huggingface_hub.snapshot_download( + repo_id, + allow_patterns=allow_patterns, + local_files_only=local_files_only, + tqdm_class=tqdm_class, + ) class ModelDownloader(QRunnable): @@ -165,22 +173,24 @@ class ModelDownloader(QRunnable): def run(self) -> None: if self.model.model_type == ModelType.WHISPER_CPP: model_name = self.model.whisper_model_size.value - url = get_hugging_face_file_url(author='ggerganov', - repository_name='whisper.cpp', - filename=f'ggml-{model_name}.bin') - file_path = get_whisper_cpp_file_path( - size=self.model.whisper_model_size) + url = get_hugging_face_file_url( + author="ggerganov", + repository_name="whisper.cpp", + filename=f"ggml-{model_name}.bin", + ) + file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size) expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name] - return self.download_model_to_path(url=url, file_path=file_path, - expected_sha256=expected_sha256) + return self.download_model_to_path( + url=url, file_path=file_path, expected_sha256=expected_sha256 + ) if self.model.model_type == ModelType.WHISPER: url = whisper._MODELS[self.model.whisper_model_size.value] - file_path = get_whisper_file_path( - size=self.model.whisper_model_size) - expected_sha256 = url.split('/')[-2] - return self.download_model_to_path(url=url, file_path=file_path, - expected_sha256=expected_sha256) + file_path = get_whisper_file_path(size=self.model.whisper_model_size) + expected_sha256 = url.split("/")[-2] + return self.download_model_to_path( + url=url, file_path=file_path, expected_sha256=expected_sha256 + ) progress = self.signals.progress @@ -197,44 +207,47 @@ class ModelDownloader(QRunnable): if self.model.model_type == ModelType.FASTER_WHISPER: model_path = download_faster_whisper_model( size=self.model.whisper_model_size.to_faster_whisper_model_size(), - tqdm_class=_tqdm) + tqdm_class=_tqdm, + ) self.signals.finished.emit(model_path) return if self.model.model_type == ModelType.HUGGING_FACE: model_path = huggingface_hub.snapshot_download( - self.model.hugging_face_model_id, tqdm_class=_tqdm) + self.model.hugging_face_model_id, tqdm_class=_tqdm + ) self.signals.finished.emit(model_path) return if self.model.model_type == ModelType.OPEN_AI_WHISPER_API: - self.signals.finished.emit('') + self.signals.finished.emit("") return raise Exception("Invalid model type: " + self.model.model_type.value) - def download_model_to_path(self, url: str, file_path: str, - expected_sha256: Optional[str]): + def download_model_to_path( + self, url: str, file_path: str, expected_sha256: Optional[str] + ): try: downloaded = self.download_model(url, file_path, expected_sha256) if downloaded: self.signals.finished.emit(file_path) except requests.RequestException: - self.signals.error.emit('A connection error occurred') - logging.exception('') + self.signals.error.emit("A connection error occurred") + logging.exception("") except Exception as exc: self.signals.error.emit(str(exc)) logging.exception(exc) - def download_model(self, url: str, file_path: str, - expected_sha256: Optional[str]) -> bool: - logging.debug(f'Downloading model from {url} to {file_path}') + def download_model( + self, url: str, file_path: str, expected_sha256: Optional[str] + ) -> bool: + logging.debug(f"Downloading model from {url} to {file_path}") os.makedirs(os.path.dirname(file_path), exist_ok=True) if os.path.exists(file_path) and not os.path.isfile(file_path): - raise RuntimeError( - f"{file_path} exists and is not a regular file") + raise RuntimeError(f"{file_path} exists and is not a regular file") if os.path.isfile(file_path): if expected_sha256 is None: @@ -246,17 +259,19 @@ class ModelDownloader(QRunnable): return True else: warnings.warn( - f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file") + f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file" + ) tmp_file = tempfile.mktemp() - logging.debug('Downloading to temporary file = %s', tmp_file) + logging.debug("Downloading to temporary file = %s", tmp_file) # Downloads the model using the requests module instead of urllib to # use the certs from certifi when the app is running in frozen mode - with requests.get(url, stream=True, timeout=15) as source, open(tmp_file, - 'wb') as output: + with requests.get(url, stream=True, timeout=15) as source, open( + tmp_file, "wb" + ) as output: source.raise_for_status() - total_size = float(source.headers.get('Content-Length', 0)) + total_size = float(source.headers.get("Content-Length", 0)) current = 0.0 self.signals.progress.emit((current, total_size)) for chunk in source.iter_content(chunk_size=8192): @@ -271,13 +286,14 @@ class ModelDownloader(QRunnable): if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: raise RuntimeError( "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the " - "model.") + "model." + ) - logging.debug('Downloaded model') + logging.debug("Downloaded model") # https://github.com/chidiwilliams/buzz/issues/454 shutil.move(tmp_file, file_path) - logging.debug('Moved file from %s to %s', tmp_file, file_path) + logging.debug("Moved file from %s to %s", tmp_file, file_path) return True def cancel(self): diff --git a/buzz/paths.py b/buzz/paths.py index 45cbe27a..9dc7552c 100644 --- a/buzz/paths.py +++ b/buzz/paths.py @@ -7,4 +7,4 @@ def file_path_as_title(file_path: str): def file_paths_as_title(file_paths: List[str]): - return ', '.join([file_path_as_title(path) for path in file_paths]) + return ", ".join([file_path_as_title(path) for path in file_paths]) diff --git a/buzz/recording.py b/buzz/recording.py index d54f3be0..aefd9513 100644 --- a/buzz/recording.py +++ b/buzz/recording.py @@ -10,19 +10,25 @@ class RecordingAmplitudeListener(QObject): stream: Optional[sounddevice.InputStream] = None amplitude_changed = pyqtSignal(float) - def __init__(self, input_device_index: Optional[int] = None, - parent: Optional[QObject] = None, - ): + def __init__( + self, + input_device_index: Optional[int] = None, + parent: Optional[QObject] = None, + ): super().__init__(parent) self.input_device_index = input_device_index def start_recording(self): try: - self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32', - channels=1, callback=self.stream_callback) + self.stream = sounddevice.InputStream( + device=self.input_device_index, + dtype="float32", + channels=1, + callback=self.stream_callback, + ) self.stream.start() except sounddevice.PortAudioError: - logging.exception('') + logging.exception("") def stop_recording(self): if self.stream is not None: @@ -31,5 +37,5 @@ class RecordingAmplitudeListener(QObject): def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status): chunk = in_data.ravel() - amplitude = np.sqrt(np.mean(chunk ** 2)) # root-mean-square + amplitude = np.sqrt(np.mean(chunk**2)) # root-mean-square self.amplitude_changed.emit(amplitude) diff --git a/buzz/recording_transcriber.py b/buzz/recording_transcriber.py index e4d2d842..917d7672 100644 --- a/buzz/recording_transcriber.py +++ b/buzz/recording_transcriber.py @@ -22,9 +22,14 @@ class RecordingTranscriber(QObject): is_running = False MAX_QUEUE_SIZE = 10 - def __init__(self, transcription_options: TranscriptionOptions, - input_device_index: Optional[int], sample_rate: int, model_path: str, - parent: Optional[QObject] = None) -> None: + def __init__( + self, + transcription_options: TranscriptionOptions, + input_device_index: Optional[int], + sample_rate: int, + model_path: str, + parent: Optional[QObject] = None, + ) -> None: super().__init__(parent) self.transcription_options = transcription_options self.current_stream = None @@ -49,60 +54,91 @@ class RecordingTranscriber(QObject): initial_prompt = self.transcription_options.initial_prompt - logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s', - self.transcription_options, model_path, self.sample_rate, self.input_device_index) + logging.debug( + "Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s", + self.transcription_options, + model_path, + self.sample_rate, + self.input_device_index, + ) self.is_running = True try: - with sounddevice.InputStream(samplerate=self.sample_rate, - device=self.input_device_index, dtype="float32", - channels=1, callback=self.stream_callback): + with sounddevice.InputStream( + samplerate=self.sample_rate, + device=self.input_device_index, + dtype="float32", + channels=1, + callback=self.stream_callback, + ): while self.is_running: self.mutex.acquire() if self.queue.size >= self.n_batch_samples: - samples = self.queue[:self.n_batch_samples] - self.queue = self.queue[self.n_batch_samples:] + samples = self.queue[: self.n_batch_samples] + self.queue = self.queue[self.n_batch_samples :] self.mutex.release() - logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s', - samples.size, self.queue.size, self.amplitude(samples)) + logging.debug( + "Processing next frame, sample size = %s, queue size = %s, amplitude = %s", + samples.size, + self.queue.size, + self.amplitude(samples), + ) time_started = datetime.datetime.now() - if self.transcription_options.model.model_type == ModelType.WHISPER: + if ( + self.transcription_options.model.model_type + == ModelType.WHISPER + ): assert isinstance(model, whisper.Whisper) result = model.transcribe( - audio=samples, language=self.transcription_options.language, + audio=samples, + language=self.transcription_options.language, task=self.transcription_options.task.value, initial_prompt=initial_prompt, - temperature=self.transcription_options.temperature) - elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP: + temperature=self.transcription_options.temperature, + ) + elif ( + self.transcription_options.model.model_type + == ModelType.WHISPER_CPP + ): assert isinstance(model, WhisperCpp) result = model.transcribe( audio=samples, params=whisper_cpp_params( language=self.transcription_options.language - if self.transcription_options.language is not None else 'en', - task=self.transcription_options.task.value, word_level_timings=False)) + if self.transcription_options.language is not None + else "en", + task=self.transcription_options.task.value, + word_level_timings=False, + ), + ) else: assert isinstance(model, TransformersWhisper) - result = model.transcribe(audio=samples, - language=self.transcription_options.language - if self.transcription_options.language is not None else 'en', - task=self.transcription_options.task.value) + result = model.transcribe( + audio=samples, + language=self.transcription_options.language + if self.transcription_options.language is not None + else "en", + task=self.transcription_options.task.value, + ) - next_text: str = result.get('text') + next_text: str = result.get("text") # Update initial prompt between successive recording chunks initial_prompt += next_text - logging.debug('Received next result, length = %s, time taken = %s', - len(next_text), datetime.datetime.now() - time_started) + logging.debug( + "Received next result, length = %s, time taken = %s", + len(next_text), + datetime.datetime.now() - time_started, + ) self.transcription.emit(next_text) else: self.mutex.release() except PortAudioError as exc: self.error.emit(str(exc)) - logging.exception('') + logging.exception("") return self.finished.emit() @@ -116,12 +152,13 @@ class RecordingTranscriber(QObject): whisper_sample_rate = whisper.audio.SAMPLE_RATE try: sounddevice.check_input_settings( - device=device_id, samplerate=whisper_sample_rate) + device=device_id, samplerate=whisper_sample_rate + ) return whisper_sample_rate except PortAudioError: device_info = sounddevice.query_devices(device=device_id) if isinstance(device_info, dict): - return int(device_info.get('default_samplerate', whisper_sample_rate)) + return int(device_info.get("default_samplerate", whisper_sample_rate)) return whisper_sample_rate def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status): diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 28614573..4ee430ae 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -3,7 +3,7 @@ import typing from PyQt6.QtCore import QSettings -APP_NAME = 'Buzz' +APP_NAME = "Buzz" class Settings: @@ -11,32 +11,38 @@ class Settings: self.settings = QSettings(APP_NAME) class Key(enum.Enum): - RECORDING_TRANSCRIBER_TASK = 'recording-transcriber/task' - RECORDING_TRANSCRIBER_MODEL = 'recording-transcriber/model' - RECORDING_TRANSCRIBER_LANGUAGE = 'recording-transcriber/language' - RECORDING_TRANSCRIBER_TEMPERATURE = 'recording-transcriber/temperature' - RECORDING_TRANSCRIBER_INITIAL_PROMPT = 'recording-transcriber/initial-prompt' + RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task" + RECORDING_TRANSCRIBER_MODEL = "recording-transcriber/model" + RECORDING_TRANSCRIBER_LANGUAGE = "recording-transcriber/language" + RECORDING_TRANSCRIBER_TEMPERATURE = "recording-transcriber/temperature" + RECORDING_TRANSCRIBER_INITIAL_PROMPT = "recording-transcriber/initial-prompt" - FILE_TRANSCRIBER_TASK = 'file-transcriber/task' - FILE_TRANSCRIBER_MODEL = 'file-transcriber/model' - FILE_TRANSCRIBER_LANGUAGE = 'file-transcriber/language' - FILE_TRANSCRIBER_TEMPERATURE = 'file-transcriber/temperature' - FILE_TRANSCRIBER_INITIAL_PROMPT = 'file-transcriber/initial-prompt' - FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings' - FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats' + FILE_TRANSCRIBER_TASK = "file-transcriber/task" + FILE_TRANSCRIBER_MODEL = "file-transcriber/model" + FILE_TRANSCRIBER_LANGUAGE = "file-transcriber/language" + FILE_TRANSCRIBER_TEMPERATURE = "file-transcriber/temperature" + FILE_TRANSCRIBER_INITIAL_PROMPT = "file-transcriber/initial-prompt" + FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = "file-transcriber/word-level-timings" + FILE_TRANSCRIBER_EXPORT_FORMATS = "file-transcriber/export-formats" - DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name' + DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name" - SHORTCUTS = 'shortcuts' + SHORTCUTS = "shortcuts" def set_value(self, key: Key, value: typing.Any) -> None: self.settings.setValue(key.value, value) - def value(self, key: Key, default_value: typing.Any, - value_type: typing.Optional[type] = None) -> typing.Any: - return self.settings.value(key.value, default_value, - value_type if value_type is not None else type( - default_value)) + def value( + self, + key: Key, + default_value: typing.Any, + value_type: typing.Optional[type] = None, + ) -> typing.Any: + return self.settings.value( + key.value, + default_value, + value_type if value_type is not None else type(default_value), + ) def clear(self): self.settings.clear() diff --git a/buzz/settings/shortcut.py b/buzz/settings/shortcut.py index 719aba7b..fbd5bd9a 100644 --- a/buzz/settings/shortcut.py +++ b/buzz/settings/shortcut.py @@ -13,13 +13,13 @@ class Shortcut(str, enum.Enum): obj.description = description return obj - OPEN_RECORD_WINDOW = ('Ctrl+R', "Open Record Window") - OPEN_IMPORT_WINDOW = ('Ctrl+O', "Import File") - OPEN_PREFERENCES_WINDOW = ('Ctrl+,', 'Open Preferences Window') + OPEN_RECORD_WINDOW = ("Ctrl+R", "Open Record Window") + OPEN_IMPORT_WINDOW = ("Ctrl+O", "Import File") + OPEN_PREFERENCES_WINDOW = ("Ctrl+,", "Open Preferences Window") - OPEN_TRANSCRIPT_EDITOR = ('Ctrl+E', "Open Transcript Viewer") - CLEAR_HISTORY = ('Ctrl+S', "Clear History") - STOP_TRANSCRIPTION = ('Ctrl+X', "Cancel Transcription") + OPEN_TRANSCRIPT_EDITOR = ("Ctrl+E", "Open Transcript Viewer") + CLEAR_HISTORY = ("Ctrl+S", "Clear History") + STOP_TRANSCRIPTION = ("Ctrl+X", "Cancel Transcription") @staticmethod def get_default_shortcuts() -> typing.Dict[str, str]: diff --git a/buzz/settings/shortcut_settings.py b/buzz/settings/shortcut_settings.py index 3e200d8f..7465723e 100644 --- a/buzz/settings/shortcut_settings.py +++ b/buzz/settings/shortcut_settings.py @@ -10,7 +10,9 @@ class ShortcutSettings: def load(self) -> typing.Dict[str, str]: shortcuts = Shortcut.get_default_shortcuts() - custom_shortcuts: typing.Dict[str, str] = self.settings.value(Settings.Key.SHORTCUTS, {}) + custom_shortcuts: typing.Dict[str, str] = self.settings.value( + Settings.Key.SHORTCUTS, {} + ) for shortcut_name in custom_shortcuts: shortcuts[shortcut_name] = custom_shortcuts[shortcut_name] return shortcuts diff --git a/buzz/store/keyring_store.py b/buzz/store/keyring_store.py index b0795f8c..21365d88 100644 --- a/buzz/store/keyring_store.py +++ b/buzz/store/keyring_store.py @@ -9,20 +9,20 @@ from buzz.settings.settings import APP_NAME class KeyringStore: class Key(enum.Enum): - OPENAI_API_KEY = 'OpenAI API key' + OPENAI_API_KEY = "OpenAI API key" def get_password(self, key: Key) -> str: try: password = keyring.get_password(APP_NAME, username=key.value) if password is None: - return '' + return "" return password except (KeyringLocked, KeyringError) as exc: - logging.error('Unable to read from keyring: %s', exc) - return '' + logging.error("Unable to read from keyring: %s", exc) + return "" def set_password(self, username: Key, password: str) -> None: try: keyring.set_password(APP_NAME, username.value, password) except (KeyringLocked, PasswordSetError) as exc: - logging.error('Unable to write to keyring: %s', exc) + logging.error("Unable to write to keyring: %s", exc) diff --git a/buzz/transcriber.py b/buzz/transcriber.py index 48316a38..ef3d3394 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -38,7 +38,7 @@ try: LOADED_WHISPER_DLL = True except ImportError: - logging.exception('') + logging.exception("") DEFAULT_WHISPER_TEMPERATURE = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) @@ -65,27 +65,28 @@ class TranscriptionOptions: model: TranscriptionModel = field(default_factory=TranscriptionModel) word_level_timings: bool = False temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE - initial_prompt: str = '' - openai_access_token: str = field(default='', - metadata=config(exclude=Exclude.ALWAYS)) + initial_prompt: str = "" + openai_access_token: str = field( + default="", metadata=config(exclude=Exclude.ALWAYS) + ) @dataclass() class FileTranscriptionOptions: file_paths: List[str] - output_formats: Set['OutputFormat'] = field(default_factory=set) - default_output_file_name: str = '' + output_formats: Set["OutputFormat"] = field(default_factory=set) + default_output_file_name: str = "" @dataclass_json @dataclass class FileTranscriptionTask: class Status(enum.Enum): - QUEUED = 'queued' - IN_PROGRESS = 'in_progress' - COMPLETED = 'completed' - FAILED = 'failed' - CANCELED = 'canceled' + QUEUED = "queued" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + CANCELED = "canceled" file_path: str transcription_options: TranscriptionOptions @@ -102,9 +103,9 @@ class FileTranscriptionTask: class OutputFormat(enum.Enum): - TXT = 'txt' - SRT = 'srt' - VTT = 'vtt' + TXT = "txt" + SRT = "srt" + VTT = "vtt" class FileTranscriber(QObject): @@ -113,8 +114,7 @@ class FileTranscriber(QObject): completed = pyqtSignal(list) # List[Segment] error = pyqtSignal(Exception) - def __init__(self, task: FileTranscriptionTask, - parent: Optional['QObject'] = None): + def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None): super().__init__(parent) self.transcription_task = task @@ -128,12 +128,16 @@ class FileTranscriber(QObject): self.completed.emit(segments) - for output_format in self.transcription_task.file_transcription_options.output_formats: - default_path = get_default_output_file_path(task=self.transcription_task, - output_format=output_format) + for ( + output_format + ) in self.transcription_task.file_transcription_options.output_formats: + default_path = get_default_output_file_path( + task=self.transcription_task, output_format=output_format + ) - write_output(path=default_path, segments=segments, - output_format=output_format) + write_output( + path=default_path, segments=segments, output_format=output_format + ) @abstractmethod def transcribe(self) -> List[Segment]: @@ -150,13 +154,14 @@ class Stopped(Exception): class WhisperCppFileTranscriber(FileTranscriber): duration_audio_ms = sys.maxsize # max int - state: 'WhisperCppFileTranscriber.State' + state: "WhisperCppFileTranscriber.State" class State: running = True - def __init__(self, task: FileTranscriptionTask, - parent: Optional['QObject'] = None) -> None: + def __init__( + self, task: FileTranscriptionTask, parent: Optional["QObject"] = None + ) -> None: super().__init__(task, parent) self.file_path = task.file_path @@ -171,24 +176,33 @@ class WhisperCppFileTranscriber(FileTranscriber): model_path = self.model_path logging.debug( - 'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, ' - 'word level timings = %s', - self.file_path, self.language, self.task, model_path, - self.word_level_timings) + "Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, " + "word level timings = %s", + self.file_path, + self.language, + self.task, + model_path, + self.word_level_timings, + ) audio = whisper.audio.load_audio(self.file_path) self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE whisper_params = whisper_cpp_params( - language=self.language if self.language is not None else '', task=self.task, - word_level_timings=self.word_level_timings) + language=self.language if self.language is not None else "", + task=self.task, + word_level_timings=self.word_level_timings, + ) whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p( - id(self.state)) - whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback( - self.encoder_begin_callback) + id(self.state) + ) + whisper_params.encoder_begin_callback = ( + whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback) + ) whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state)) whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback( - self.new_segment_callback) + self.new_segment_callback + ) model = WhisperCpp(model=model_path) result = model.transcribe(audio=self.file_path, params=whisper_params) @@ -197,7 +211,7 @@ class WhisperCppFileTranscriber(FileTranscriber): raise Stopped self.state.running = False - return result['segments'] + return result["segments"] def new_segment_callback(self, ctx, _state, _n_new, user_data): n_segments = whisper_cpp.whisper_full_n_segments(ctx) @@ -205,15 +219,17 @@ class WhisperCppFileTranscriber(FileTranscriber): # t1 seems to sometimes be larger than the duration when the # audio ends in silence. Trim to fix the displayed progress. progress = min(t1 * 10, self.duration_audio_ms) - state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, - ctypes.py_object).value + state: WhisperCppFileTranscriber.State = ctypes.cast( + user_data, ctypes.py_object + ).value if state.running: self.progress.emit((progress, self.duration_audio_ms)) @staticmethod def encoder_begin_callback(_ctx, _state, user_data): - state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, - ctypes.py_object).value + state: WhisperCppFileTranscriber.State = ctypes.cast( + user_data, ctypes.py_object + ).value return state.running == 1 def stop(self): @@ -221,18 +237,19 @@ class WhisperCppFileTranscriber(FileTranscriber): class OpenAIWhisperAPIFileTranscriber(FileTranscriber): - def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = None): + def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None): super().__init__(task=task, parent=parent) self.file_path = task.file_path self.task = task.transcription_options.task def transcribe(self) -> List[Segment]: logging.debug( - 'Starting OpenAI Whisper API file transcription, file path = %s, task = %s', + "Starting OpenAI Whisper API file transcription, file path = %s, task = %s", self.file_path, - self.task) + self.task, + ) - wav_file = tempfile.mktemp() + '.wav' + wav_file = tempfile.mktemp() + ".wav" ( ffmpeg.input(self.file_path) .output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE) @@ -241,22 +258,30 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber): # TODO: Check if file size is more than 25MB (2.5 minutes), then chunk audio_file = open(wav_file, "rb") - openai.api_key = self.transcription_task.transcription_options.openai_access_token + openai.api_key = ( + self.transcription_task.transcription_options.openai_access_token + ) language = self.transcription_task.transcription_options.language response_format = "verbose_json" if self.transcription_task.transcription_options.task == Task.TRANSLATE: - transcript = openai.Audio.translate("whisper-1", audio_file, - response_format=response_format, - language=language) + transcript = openai.Audio.translate( + "whisper-1", + audio_file, + response_format=response_format, + language=language, + ) else: - transcript = openai.Audio.transcribe("whisper-1", audio_file, - response_format=response_format, - language=language) + transcript = openai.Audio.transcribe( + "whisper-1", + audio_file, + response_format=response_format, + language=language, + ) segments = [ - Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for - segment in - transcript["segments"]] + Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) + for segment in transcript["segments"] + ] return segments def stop(self): @@ -265,15 +290,16 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber): class WhisperFileTranscriber(FileTranscriber): """WhisperFileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file - using the default program for opening txt files. """ + using the default program for opening txt files.""" current_process: multiprocessing.Process running = False read_line_thread: Optional[Thread] = None - READ_LINE_THREAD_STOP_TOKEN = '--STOP--' + READ_LINE_THREAD_STOP_TOKEN = "--STOP--" - def __init__(self, task: FileTranscriptionTask, - parent: Optional['QObject'] = None) -> None: + def __init__( + self, task: FileTranscriptionTask, parent: Optional["QObject"] = None + ) -> None: super().__init__(task, parent) self.segments = [] self.started_process = False @@ -282,19 +308,19 @@ class WhisperFileTranscriber(FileTranscriber): def transcribe(self) -> List[Segment]: time_started = datetime.datetime.now() logging.debug( - 'Starting whisper file transcription, task = %s', self.transcription_task) + "Starting whisper file transcription, task = %s", self.transcription_task + ) recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False) - self.current_process = multiprocessing.Process(target=self.transcribe_whisper, - args=(send_pipe, - self.transcription_task)) + self.current_process = multiprocessing.Process( + target=self.transcribe_whisper, args=(send_pipe, self.transcription_task) + ) if not self.stopped: self.current_process.start() self.started_process = True - self.read_line_thread = Thread( - target=self.read_line, args=(recv_pipe,)) + self.read_line_thread = Thread(target=self.read_line, args=(recv_pipe,)) self.read_line_thread.start() self.current_process.join() @@ -305,76 +331,96 @@ class WhisperFileTranscriber(FileTranscriber): self.read_line_thread.join() logging.debug( - 'whisper process completed with code = %s, time taken = %s, number of segments = %s', - self.current_process.exitcode, datetime.datetime.now() - time_started, - len(self.segments)) + "whisper process completed with code = %s, time taken = %s, number of segments = %s", + self.current_process.exitcode, + datetime.datetime.now() - time_started, + len(self.segments), + ) if self.current_process.exitcode != 0: - raise Exception('Unknown error') + raise Exception("Unknown error") return self.segments @classmethod - def transcribe_whisper(cls, stderr_conn: Connection, - task: FileTranscriptionTask) -> None: + def transcribe_whisper( + cls, stderr_conn: Connection, task: FileTranscriptionTask + ) -> None: with pipe_stderr(stderr_conn): if task.transcription_options.model.model_type == ModelType.HUGGING_FACE: segments = cls.transcribe_hugging_face(task) - elif task.transcription_options.model.model_type == ModelType.FASTER_WHISPER: + elif ( + task.transcription_options.model.model_type == ModelType.FASTER_WHISPER + ): segments = cls.transcribe_faster_whisper(task) elif task.transcription_options.model.model_type == ModelType.WHISPER: segments = cls.transcribe_openai_whisper(task) else: raise Exception( - f"Invalid model type: {task.transcription_options.model.model_type}") + f"Invalid model type: {task.transcription_options.model.model_type}" + ) - segments_json = json.dumps( - segments, ensure_ascii=True, default=vars) - sys.stderr.write(f'segments = {segments_json}\n') - sys.stderr.write( - WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + '\n') + segments_json = json.dumps(segments, ensure_ascii=True, default=vars) + sys.stderr.write(f"segments = {segments_json}\n") + sys.stderr.write(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + "\n") @classmethod def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]: model = transformers_whisper.load_model(task.model_path) - language = task.transcription_options.language if task.transcription_options.language is not None else 'en' - result = model.transcribe(audio=task.file_path, language=language, - task=task.transcription_options.task.value, - verbose=False) + language = ( + task.transcription_options.language + if task.transcription_options.language is not None + else "en" + ) + result = model.transcribe( + audio=task.file_path, + language=language, + task=task.transcription_options.task.value, + verbose=False, + ) return [ Segment( - start=int(segment.get('start') * 1000), - end=int(segment.get('end') * 1000), - text=segment.get('text'), - ) for segment in result.get('segments')] + start=int(segment.get("start") * 1000), + end=int(segment.get("end") * 1000), + text=segment.get("text"), + ) + for segment in result.get("segments") + ] @classmethod def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]: model = faster_whisper.WhisperModel( - model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()) - whisper_segments, info = model.transcribe(audio=task.file_path, - language=task.transcription_options.language, - task=task.transcription_options.task.value, - temperature=task.transcription_options.temperature, - initial_prompt=task.transcription_options.initial_prompt, - word_timestamps=task.transcription_options.word_level_timings) + model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size() + ) + whisper_segments, info = model.transcribe( + audio=task.file_path, + language=task.transcription_options.language, + task=task.transcription_options.task.value, + temperature=task.transcription_options.temperature, + initial_prompt=task.transcription_options.initial_prompt, + word_timestamps=task.transcription_options.word_level_timings, + ) segments = [] - with tqdm.tqdm(total=round(info.duration, 2), unit=' seconds') as pbar: + with tqdm.tqdm(total=round(info.duration, 2), unit=" seconds") as pbar: for segment in list(whisper_segments): # Segment will contain words if word-level timings is True if segment.words: for word in segment.words: - segments.append(Segment( - start=int(word.start * 1000), - end=int(word.end * 1000), - text=word.word - )) + segments.append( + Segment( + start=int(word.start * 1000), + end=int(word.end * 1000), + text=word.word, + ) + ) else: - segments.append(Segment( - start=int(segment.start * 1000), - end=int(segment.end * 1000), - text=segment.text - )) + segments.append( + Segment( + start=int(segment.start * 1000), + end=int(segment.end * 1000), + text=segment.text, + ) + ) pbar.update(segment.end - segment.start) return segments @@ -386,28 +432,40 @@ class WhisperFileTranscriber(FileTranscriber): if task.transcription_options.word_level_timings: stable_whisper.modify_model(model) result = model.transcribe( - audio=task.file_path, language=task.transcription_options.language, + audio=task.file_path, + language=task.transcription_options.language, task=task.transcription_options.task.value, temperature=task.transcription_options.temperature, - initial_prompt=task.transcription_options.initial_prompt, pbar=True) + initial_prompt=task.transcription_options.initial_prompt, + pbar=True, + ) segments = stable_whisper.group_word_timestamps(result) - return [Segment( - start=int(segment.get('start') * 1000), - end=int(segment.get('end') * 1000), - text=segment.get('text'), - ) for segment in segments] + return [ + Segment( + start=int(segment.get("start") * 1000), + end=int(segment.get("end") * 1000), + text=segment.get("text"), + ) + for segment in segments + ] result = model.transcribe( - audio=task.file_path, language=task.transcription_options.language, + audio=task.file_path, + language=task.transcription_options.language, task=task.transcription_options.task.value, temperature=task.transcription_options.temperature, - initial_prompt=task.transcription_options.initial_prompt, verbose=False) - segments = result.get('segments') - return [Segment( - start=int(segment.get('start') * 1000), - end=int(segment.get('end') * 1000), - text=segment.get('text'), - ) for segment in segments] + initial_prompt=task.transcription_options.initial_prompt, + verbose=False, + ) + segments = result.get("segments") + return [ + Segment( + start=int(segment.get("start") * 1000), + end=int(segment.get("end") * 1000), + text=segment.get("text"), + ) + for segment in segments + ] def stop(self): self.stopped = True @@ -424,102 +482,119 @@ class WhisperFileTranscriber(FileTranscriber): if line == self.READ_LINE_THREAD_STOP_TOKEN: return - if line.startswith('segments = '): + if line.startswith("segments = "): segments_dict = json.loads(line[11:]) - segments = [Segment( - start=segment.get('start'), - end=segment.get('end'), - text=segment.get('text'), - ) for segment in segments_dict] + segments = [ + Segment( + start=segment.get("start"), + end=segment.get("end"), + text=segment.get("text"), + ) + for segment in segments_dict + ] self.segments = segments else: try: - progress = int(line.split('|')[0].strip().strip('%')) + progress = int(line.split("|")[0].strip().strip("%")) self.progress.emit((progress, 100)) except ValueError: - logging.debug('whisper (stderr): %s', line) + logging.debug("whisper (stderr): %s", line) continue def write_output(path: str, segments: List[Segment], output_format: OutputFormat): logging.debug( - 'Writing transcription output, path = %s, output format = %s, number of segments = %s', - path, output_format, - len(segments)) + "Writing transcription output, path = %s, output format = %s, number of segments = %s", + path, + output_format, + len(segments), + ) - with open(path, 'w', encoding='utf-8') as file: + with open(path, "w", encoding="utf-8") as file: if output_format == OutputFormat.TXT: - for (i, segment) in enumerate(segments): + for i, segment in enumerate(segments): file.write(segment.text) - file.write('\n') + file.write("\n") elif output_format == OutputFormat.VTT: - file.write('WEBVTT\n\n') + file.write("WEBVTT\n\n") for segment in segments: file.write( - f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n') - file.write(f'{segment.text}\n\n') + f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n" + ) + file.write(f"{segment.text}\n\n") elif output_format == OutputFormat.SRT: - for (i, segment) in enumerate(segments): - file.write(f'{i + 1}\n') + for i, segment in enumerate(segments): + file.write(f"{i + 1}\n") file.write( - f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n') - file.write(f'{segment.text}\n\n') + f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n' + ) + file.write(f"{segment.text}\n\n") - logging.debug('Written transcription output') + logging.debug("Written transcription output") def segments_to_text(segments: List[Segment]) -> str: - result = '' - for (i, segment) in enumerate(segments): - result += f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n' - result += f'{segment.text}' + result = "" + for i, segment in enumerate(segments): + result += f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n" + result += f"{segment.text}" if i < len(segments) - 1: - result += '\n\n' + result += "\n\n" return result -def to_timestamp(ms: float, ms_separator='.') -> str: +def to_timestamp(ms: float, ms_separator=".") -> str: hr = int(ms / (1000 * 60 * 60)) ms = ms - hr * (1000 * 60 * 60) min = int(ms / (1000 * 60)) ms = ms - min * (1000 * 60) sec = int(ms / 1000) ms = int(ms - sec * 1000) - return f'{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}' + return f"{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}" -SUPPORTED_OUTPUT_FORMATS = 'Audio files (*.mp3 *.wav *.m4a *.ogg);;\ -Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)' +SUPPORTED_OUTPUT_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg);;\ +Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)" -def get_default_output_file_path(task: FileTranscriptionTask, - output_format: OutputFormat): +def get_default_output_file_path( + task: FileTranscriptionTask, output_format: OutputFormat +): input_file_name = os.path.splitext(task.file_path)[0] - date_time_now = datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S') - return (task.file_transcription_options.default_output_file_name - .replace('{{ input_file_name }}', input_file_name) - .replace('{{ task }}', task.transcription_options.task.value) - .replace('{{ language }}', task.transcription_options.language or '') - .replace('{{ model_type }}', - task.transcription_options.model.model_type.value) - .replace('{{ model_size }}', - task.transcription_options.model.whisper_model_size.value if - task.transcription_options.model.whisper_model_size is not None else - '') - .replace('{{ date_time }}', date_time_now) - + f".{output_format.value}") + date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S") + return ( + task.file_transcription_options.default_output_file_name.replace( + "{{ input_file_name }}", input_file_name + ) + .replace("{{ task }}", task.transcription_options.task.value) + .replace("{{ language }}", task.transcription_options.language or "") + .replace("{{ model_type }}", task.transcription_options.model.model_type.value) + .replace( + "{{ model_size }}", + task.transcription_options.model.whisper_model_size.value + if task.transcription_options.model.whisper_model_size is not None + else "", + ) + .replace("{{ date_time }}", date_time_now) + + f".{output_format.value}" + ) def whisper_cpp_params( - language: str, task: Task, word_level_timings: bool, - print_realtime=False, print_progress=False, ): + language: str, + task: Task, + word_level_timings: bool, + print_realtime=False, + print_progress=False, +): params = whisper_cpp.whisper_full_default_params( - whisper_cpp.WHISPER_SAMPLING_GREEDY) + whisper_cpp.WHISPER_SAMPLING_GREEDY + ) params.print_realtime = print_realtime params.print_progress = print_progress - params.language = whisper_cpp.String(language.encode('utf-8')) + params.language = whisper_cpp.String(language.encode("utf-8")) params.translate = task == Task.TRANSLATE params.max_len = ctypes.c_int(1) params.max_len = 1 if word_level_timings else 0 @@ -529,20 +604,20 @@ def whisper_cpp_params( class WhisperCpp: def __init__(self, model: str) -> None: - self.ctx = whisper_cpp.whisper_init_from_file(model.encode('utf-8')) + self.ctx = whisper_cpp.whisper_init_from_file(model.encode("utf-8")) def transcribe(self, audio: Union[np.ndarray, str], params: Any): if isinstance(audio, str): audio = whisper.audio.load_audio(audio) - logging.debug('Loaded audio with length = %s', len(audio)) + logging.debug("Loaded audio with length = %s", len(audio)) - whisper_cpp_audio = audio.ctypes.data_as( - ctypes.POINTER(ctypes.c_float)) + whisper_cpp_audio = audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) result = whisper_cpp.whisper_full( - self.ctx, params, whisper_cpp_audio, len(audio)) + self.ctx, params, whisper_cpp_audio, len(audio) + ) if result != 0: - raise Exception(f'Error from whisper.cpp: {result}') + raise Exception(f"Error from whisper.cpp: {result}") segments: List[Segment] = [] @@ -553,13 +628,17 @@ class WhisperCpp: t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i) segments.append( - Segment(start=t0 * 10, # centisecond to ms - end=t1 * 10, # centisecond to ms - text=txt.decode('utf-8'))) + Segment( + start=t0 * 10, # centisecond to ms + end=t1 * 10, # centisecond to ms + text=txt.decode("utf-8"), + ) + ) return { - 'segments': segments, - 'text': ''.join([segment.text for segment in segments])} + "segments": segments, + "text": "".join([segment.text for segment in segments]), + } def __del__(self): whisper_cpp.whisper_free(self.ctx) diff --git a/buzz/transformers_whisper.py b/buzz/transformers_whisper.py index eec710c8..15ee6d37 100644 --- a/buzz/transformers_whisper.py +++ b/buzz/transformers_whisper.py @@ -16,43 +16,65 @@ class TransformersWhisper: SAMPLE_RATE = whisper.audio.SAMPLE_RATE N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES - def __init__(self, processor: WhisperProcessor, model: WhisperForConditionalGeneration): + def __init__( + self, processor: WhisperProcessor, model: WhisperForConditionalGeneration + ): self.processor = processor self.model = model # Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and # timestamps are available. See: https://github.com/huggingface/transformers/issues/19887, # https://github.com/huggingface/transformers/pull/20620. - def transcribe(self, audio: Union[str, np.ndarray], language: str, task: str, verbose: Optional[bool] = None): + def transcribe( + self, + audio: Union[str, np.ndarray], + language: str, + task: str, + verbose: Optional[bool] = None, + ): if isinstance(audio, str): audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE) - self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(task=task, language=language) + self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids( + task=task, language=language + ) segments = [] all_predicted_ids = [] num_samples = audio.size seek = 0 - with tqdm(total=num_samples, unit='samples', disable=verbose is not False) as progress_bar: + with tqdm( + total=num_samples, unit="samples", disable=verbose is not False + ) as progress_bar: while seek < num_samples: - chunk = audio[seek: seek + self.N_SAMPLES_IN_CHUNK] - input_features = self.processor(chunk, return_tensors="pt", - sampling_rate=self.SAMPLE_RATE).input_features + chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK] + input_features = self.processor( + chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE + ).input_features predicted_ids = self.model.generate(input_features) all_predicted_ids.extend(predicted_ids) - text: str = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] - if text.strip() != '': - segments.append({ - 'start': seek / self.SAMPLE_RATE, - 'end': min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) / self.SAMPLE_RATE, - 'text': text - }) + text: str = self.processor.batch_decode( + predicted_ids, skip_special_tokens=True + )[0] + if text.strip() != "": + segments.append( + { + "start": seek / self.SAMPLE_RATE, + "end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) + / self.SAMPLE_RATE, + "text": text, + } + ) - progress_bar.update(min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek) + progress_bar.update( + min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek + ) seek += self.N_SAMPLES_IN_CHUNK return { - 'text': self.processor.batch_decode(all_predicted_ids, skip_special_tokens=True)[0], - 'segments': segments + "text": self.processor.batch_decode( + all_predicted_ids, skip_special_tokens=True + )[0], + "segments": segments, } diff --git a/buzz/widgets/about_dialog.py b/buzz/widgets/about_dialog.py index a8e62b23..d4b3abbb 100644 --- a/buzz/widgets/about_dialog.py +++ b/buzz/widgets/about_dialog.py @@ -5,8 +5,15 @@ from PyQt6 import QtGui from PyQt6.QtCore import Qt, QUrl from PyQt6.QtGui import QIcon, QPixmap, QDesktopServices from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply -from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QLabel, QPushButton, \ - QDialogButtonBox, QMessageBox +from PyQt6.QtWidgets import ( + QDialog, + QWidget, + QVBoxLayout, + QLabel, + QPushButton, + QDialogButtonBox, + QMessageBox, +) from buzz.__version__ import VERSION from buzz.widgets.icon import BUZZ_ICON_PATH, BUZZ_LARGE_ICON_PATH @@ -15,11 +22,16 @@ from buzz.settings.settings import APP_NAME class AboutDialog(QDialog): - GITHUB_API_LATEST_RELEASE_URL = 'https://api.github.com/repos/chidiwilliams/buzz/releases/latest' - GITHUB_LATEST_RELEASE_URL = 'https://github.com/chidiwilliams/buzz/releases/latest' + GITHUB_API_LATEST_RELEASE_URL = ( + "https://api.github.com/repos/chidiwilliams/buzz/releases/latest" + ) + GITHUB_LATEST_RELEASE_URL = "https://github.com/chidiwilliams/buzz/releases/latest" - def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None, - parent: Optional[QWidget] = None) -> None: + def __init__( + self, + network_access_manager: Optional[QNetworkAccessManager] = None, + parent: Optional[QWidget] = None, + ) -> None: super().__init__(parent) self.setWindowIcon(QIcon(BUZZ_ICON_PATH)) @@ -35,28 +47,42 @@ class AboutDialog(QDialog): image_label = QLabel() pixmap = QPixmap(BUZZ_LARGE_ICON_PATH).scaled( - 80, 80, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + 80, + 80, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) image_label.setPixmap(pixmap) - image_label.setAlignment(Qt.AlignmentFlag( - Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter)) + image_label.setAlignment( + Qt.AlignmentFlag( + Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter + ) + ) buzz_label = QLabel(APP_NAME) - buzz_label.setAlignment(Qt.AlignmentFlag( - Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter)) + buzz_label.setAlignment( + Qt.AlignmentFlag( + Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter + ) + ) buzz_label_font = QtGui.QFont() buzz_label_font.setBold(True) buzz_label_font.setPointSize(20) buzz_label.setFont(buzz_label_font) version_label = QLabel(f"{_('Version')} {VERSION}") - version_label.setAlignment(Qt.AlignmentFlag( - Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter)) + version_label.setAlignment( + Qt.AlignmentFlag( + Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter + ) + ) - self.check_updates_button = QPushButton(_('Check for updates'), self) + self.check_updates_button = QPushButton(_("Check for updates"), self) self.check_updates_button.clicked.connect(self.on_click_check_for_updates) - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton( - QDialogButtonBox.StandardButton.Close), self) + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Close), self + ) button_box.accepted.connect(self.accept) button_box.rejected.connect(self.reject) @@ -76,13 +102,13 @@ class AboutDialog(QDialog): def on_latest_release_reply(self, reply: QNetworkReply): if reply.error() == QNetworkReply.NetworkError.NoError: response = json.loads(reply.readAll().data()) - tag_name = response.get('name') + tag_name = response.get("name") if self.is_version_lower(VERSION, tag_name[1:]): QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL)) else: - QMessageBox.information(self, '', _("You're up to date!")) + QMessageBox.information(self, "", _("You're up to date!")) self.check_updates_button.setEnabled(True) @staticmethod def is_version_lower(version_a: str, version_b: str): - return version_a.replace('.', '') < version_b.replace('.', '') + return version_a.replace(".", "") < version_b.replace(".", "") diff --git a/buzz/widgets/audio_player.py b/buzz/widgets/audio_player.py index 4dfa683f..a26fa858 100644 --- a/buzz/widgets/audio_player.py +++ b/buzz/widgets/audio_player.py @@ -111,7 +111,7 @@ class AudioPlayer(QWidget): def update_time_label(self): position_time = QTime(0, 0).addMSecs(self.position).toString() duration_time = QTime(0, 0).addMSecs(self.duration).toString() - self.time_label.setText(f'{position_time} / {duration_time}') + self.time_label.setText(f"{position_time} / {duration_time}") def stop(self): self.media_player.stop() diff --git a/buzz/widgets/icon.py b/buzz/widgets/icon.py index f6e92584..ed5f92f7 100644 --- a/buzz/widgets/icon.py +++ b/buzz/widgets/icon.py @@ -6,8 +6,8 @@ from buzz.assets import get_asset_path # TODO: move icons to Qt resources: https://stackoverflow.com/a/52341917/9830227 class Icon(QIcon): - LIGHT_THEME_BACKGROUND = '#555' - DARK_THEME_BACKGROUND = '#EEE' + LIGHT_THEME_BACKGROUND = "#555" + DARK_THEME_BACKGROUND = "#EEE" def __init__(self, path: str, parent: QWidget): # Adapted from https://stackoverflow.com/questions/15123544/change-the-color-of-an-svg-in-qt @@ -23,18 +23,20 @@ class Icon(QIcon): super().__init__(pixmap) def get_color(self, is_dark_theme): - return self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND + return ( + self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND + ) class PlayIcon(Icon): def __init__(self, parent: QWidget): - super().__init__(get_asset_path('assets/play_arrow_black_24dp.svg'), parent) + super().__init__(get_asset_path("assets/play_arrow_black_24dp.svg"), parent) class PauseIcon(Icon): def __init__(self, parent: QWidget): - super().__init__(get_asset_path('assets/pause_black_24dp.svg'), parent) + super().__init__(get_asset_path("assets/pause_black_24dp.svg"), parent) -BUZZ_ICON_PATH = get_asset_path('assets/buzz.ico') -BUZZ_LARGE_ICON_PATH = get_asset_path('assets/buzz-icon-1024.png') +BUZZ_ICON_PATH = get_asset_path("assets/buzz.ico") +BUZZ_LARGE_ICON_PATH = get_asset_path("assets/buzz-icon-1024.png") diff --git a/buzz/widgets/line_edit.py b/buzz/widgets/line_edit.py index 2c33165d..fe1631ae 100644 --- a/buzz/widgets/line_edit.py +++ b/buzz/widgets/line_edit.py @@ -5,7 +5,7 @@ from PyQt6.QtWidgets import QLineEdit, QWidget class LineEdit(QLineEdit): - def __init__(self, default_text: str = '', parent: Optional[QWidget] = None): + def __init__(self, default_text: str = "", parent: Optional[QWidget] = None): super().__init__(default_text, parent) - if platform.system() == 'Darwin': - self.setStyleSheet('QLineEdit { padding: 4px }') + if platform.system() == "Darwin": + self.setStyleSheet("QLineEdit { padding: 4px }") diff --git a/buzz/widgets/menu_bar.py b/buzz/widgets/menu_bar.py index 65ae2f91..14e61ef5 100644 --- a/buzz/widgets/menu_bar.py +++ b/buzz/widgets/menu_bar.py @@ -18,16 +18,16 @@ class MenuBar(QMenuBar): openai_api_key_changed = pyqtSignal(str) default_export_file_name_changed = pyqtSignal(str) - def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str, - parent: QWidget): + def __init__( + self, shortcuts: Dict[str, str], default_export_file_name: str, parent: QWidget + ): super().__init__(parent) self.shortcuts = shortcuts self.default_export_file_name = default_export_file_name self.import_action = QAction(_("Import Media File..."), self) - self.import_action.triggered.connect( - self.on_import_action_triggered) + self.import_action.triggered.connect(self.on_import_action_triggered) about_action = QAction(f'{_("About")} {APP_NAME}', self) about_action.triggered.connect(self.on_about_action_triggered) @@ -56,22 +56,27 @@ class MenuBar(QMenuBar): about_dialog.open() def on_preferences_action_triggered(self): - preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, - default_export_file_name=self.default_export_file_name, - parent=self) + preferences_dialog = PreferencesDialog( + shortcuts=self.shortcuts, + default_export_file_name=self.default_export_file_name, + parent=self, + ) preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed) preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed) preferences_dialog.default_export_file_name_changed.connect( - self.default_export_file_name_changed) + self.default_export_file_name_changed + ) preferences_dialog.open() def on_help_action_triggered(self): - webbrowser.open('https://chidiwilliams.github.io/buzz/docs') + webbrowser.open("https://chidiwilliams.github.io/buzz/docs") def set_shortcuts(self, shortcuts: Dict[str, str]): self.shortcuts = shortcuts self.import_action.setShortcut( - QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name])) + QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]) + ) self.preferences_action.setShortcut( - QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name])) + QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name]) + ) diff --git a/buzz/widgets/model_download_progress_dialog.py b/buzz/widgets/model_download_progress_dialog.py index 6a0ce74e..727ed4d3 100644 --- a/buzz/widgets/model_download_progress_dialog.py +++ b/buzz/widgets/model_download_progress_dialog.py @@ -10,10 +10,17 @@ from buzz.model_loader import ModelType class ModelDownloadProgressDialog(QProgressDialog): - def __init__(self, model_type: ModelType, parent: Optional[QWidget] = None, modality=Qt.WindowModality.WindowModal): + def __init__( + self, + model_type: ModelType, + parent: Optional[QWidget] = None, + modality=Qt.WindowModality.WindowModal, + ): super().__init__(parent) - self.cancelable = model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP + self.cancelable = ( + model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP + ) self.start_time = datetime.now() self.setRange(0, 100) self.setMinimumDuration(0) @@ -21,7 +28,7 @@ class ModelDownloadProgressDialog(QProgressDialog): self.update_label_text(0) if not self.cancelable: - cancel_button = QPushButton('Cancel', self) + cancel_button = QPushButton("Cancel", self) cancel_button.setEnabled(False) self.setCancelButton(cancel_button) @@ -30,8 +37,8 @@ class ModelDownloadProgressDialog(QProgressDialog): if fraction_completed > 0: time_spent = (datetime.now() - self.start_time).total_seconds() time_left = (time_spent / fraction_completed) - time_spent - label_text += f', {humanize.naturaldelta(time_left)} remaining' - label_text += ')' + label_text += f", {humanize.naturaldelta(time_left)} remaining" + label_text += ")" self.setLabelText(label_text) diff --git a/buzz/widgets/model_type_combo_box.py b/buzz/widgets/model_type_combo_box.py index 6e8aea0e..674e69ab 100644 --- a/buzz/widgets/model_type_combo_box.py +++ b/buzz/widgets/model_type_combo_box.py @@ -10,8 +10,12 @@ from buzz.transcriber import LOADED_WHISPER_DLL class ModelTypeComboBox(QComboBox): changed = pyqtSignal(ModelType) - def __init__(self, model_types: Optional[List[ModelType]] = None, default_model: Optional[ModelType] = None, - parent: Optional[QWidget] = None): + def __init__( + self, + model_types: Optional[List[ModelType]] = None, + default_model: Optional[ModelType] = None, + parent: Optional[QWidget] = None, + ): super().__init__(parent) if model_types is None: diff --git a/buzz/widgets/openai_api_key_line_edit.py b/buzz/widgets/openai_api_key_line_edit.py index 01cfe413..034b84e2 100644 --- a/buzz/widgets/openai_api_key_line_edit.py +++ b/buzz/widgets/openai_api_key_line_edit.py @@ -16,15 +16,22 @@ class OpenAIAPIKeyLineEdit(LineEdit): self.key = key - self.visible_on_icon = Icon(get_asset_path('assets/visibility_FILL0_wght700_GRAD0_opsz48.svg'), self) - self.visible_off_icon = Icon(get_asset_path('assets/visibility_off_FILL0_wght700_GRAD0_opsz48.svg'), self) + self.visible_on_icon = Icon( + get_asset_path("assets/visibility_FILL0_wght700_GRAD0_opsz48.svg"), self + ) + self.visible_off_icon = Icon( + get_asset_path("assets/visibility_off_FILL0_wght700_GRAD0_opsz48.svg"), self + ) - self.setPlaceholderText('sk-...') + self.setPlaceholderText("sk-...") self.setEchoMode(QLineEdit.EchoMode.Password) self.textChanged.connect(self.on_openai_api_key_changed) - self.toggle_show_openai_api_key_action = self.addAction(self.visible_on_icon, - QLineEdit.ActionPosition.TrailingPosition) - self.toggle_show_openai_api_key_action.triggered.connect(self.on_toggle_show_action_triggered) + self.toggle_show_openai_api_key_action = self.addAction( + self.visible_on_icon, QLineEdit.ActionPosition.TrailingPosition + ) + self.toggle_show_openai_api_key_action.triggered.connect( + self.on_toggle_show_action_triggered + ) def on_toggle_show_action_triggered(self): if self.echoMode() == QLineEdit.EchoMode.Password: diff --git a/buzz/widgets/preferences_dialog/general_preferences_widget.py b/buzz/widgets/preferences_dialog/general_preferences_widget.py index 25e07afa..bbc1cfef 100644 --- a/buzz/widgets/preferences_dialog/general_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/general_preferences_widget.py @@ -15,32 +15,40 @@ class GeneralPreferencesWidget(QWidget): openai_api_key_changed = pyqtSignal(str) default_export_file_name_changed = pyqtSignal(str) - def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(), - parent: Optional[QWidget] = None): + def __init__( + self, + default_export_file_name: str, + keyring_store=KeyringStore(), + parent: Optional[QWidget] = None, + ): super().__init__(parent) self.openai_api_key = keyring_store.get_password( - KeyringStore.Key.OPENAI_API_KEY) + KeyringStore.Key.OPENAI_API_KEY + ) layout = QFormLayout(self) self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self) self.openai_api_key_line_edit.key_changed.connect( - self.on_openai_api_key_changed) + self.on_openai_api_key_changed + ) - self.test_openai_api_key_button = QPushButton('Test') + self.test_openai_api_key_button = QPushButton("Test") self.test_openai_api_key_button.clicked.connect( - self.on_click_test_openai_api_key_button) + self.on_click_test_openai_api_key_button + ) self.update_test_openai_api_key_button() - layout.addRow('OpenAI API Key', self.openai_api_key_line_edit) - layout.addRow('', self.test_openai_api_key_button) + layout.addRow("OpenAI API Key", self.openai_api_key_line_edit) + layout.addRow("", self.test_openai_api_key_button) default_export_file_name_line_edit = LineEdit(default_export_file_name, self) default_export_file_name_line_edit.textChanged.connect( - self.default_export_file_name_changed) + self.default_export_file_name_changed + ) default_export_file_name_line_edit.setMinimumWidth(200) - layout.addRow('Default export file name', default_export_file_name_line_edit) + layout.addRow("Default export file name", default_export_file_name_line_edit) self.setLayout(layout) @@ -60,12 +68,15 @@ class GeneralPreferencesWidget(QWidget): def on_test_openai_api_key_success(self): self.test_openai_api_key_button.setEnabled(True) - QMessageBox.information(self, 'OpenAI API Key Test', - 'Your API key is valid. Buzz will use this key to perform Whisper API transcriptions.') + QMessageBox.information( + self, + "OpenAI API Key Test", + "Your API key is valid. Buzz will use this key to perform Whisper API transcriptions.", + ) def on_test_openai_api_key_failure(self, error: str): self.test_openai_api_key_button.setEnabled(True) - QMessageBox.warning(self, 'OpenAI API Key Test', error) + QMessageBox.warning(self, "OpenAI API Key Test", error) def on_openai_api_key_changed(self, key: str): self.openai_api_key = key diff --git a/buzz/widgets/preferences_dialog/models_preferences_widget.py b/buzz/widgets/preferences_dialog/models_preferences_widget.py index db535d99..4b0a4b86 100644 --- a/buzz/widgets/preferences_dialog/models_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/models_preferences_widget.py @@ -1,34 +1,55 @@ from typing import Optional from PyQt6.QtCore import Qt, QThreadPool -from PyQt6.QtWidgets import QWidget, QFormLayout, QTreeWidget, QTreeWidgetItem, \ - QPushButton, QMessageBox, QHBoxLayout +from PyQt6.QtWidgets import ( + QWidget, + QFormLayout, + QTreeWidget, + QTreeWidgetItem, + QPushButton, + QMessageBox, + QHBoxLayout, +) from buzz.locale import _ -from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel, ModelDownloader +from buzz.model_loader import ( + ModelType, + WhisperModelSize, + TranscriptionModel, + ModelDownloader, +) from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog from buzz.widgets.model_type_combo_box import ModelTypeComboBox class ModelsPreferencesWidget(QWidget): - def __init__(self, progress_dialog_modality=Qt.WindowModality.WindowModal, - parent: Optional[QWidget] = None): + def __init__( + self, + progress_dialog_modality=Qt.WindowModality.WindowModal, + parent: Optional[QWidget] = None, + ): super().__init__(parent) self.model_downloader: Optional[ModelDownloader] = None - self.model = TranscriptionModel(model_type=ModelType.WHISPER, - whisper_model_size=WhisperModelSize.TINY) + self.model = TranscriptionModel( + model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY + ) self.progress_dialog_modality = progress_dialog_modality self.progress_dialog: Optional[ModelDownloadProgressDialog] = None layout = QFormLayout() model_type_combo_box = ModelTypeComboBox( - model_types=[ModelType.WHISPER, ModelType.WHISPER_CPP, - ModelType.FASTER_WHISPER], - default_model=self.model.model_type, parent=self) + model_types=[ + ModelType.WHISPER, + ModelType.WHISPER_CPP, + ModelType.FASTER_WHISPER, + ], + default_model=self.model.model_type, + parent=self, + ) model_type_combo_box.changed.connect(self.on_model_type_changed) - layout.addRow('Group', model_type_combo_box) + layout.addRow("Group", model_type_combo_box) self.model_list_widget = QTreeWidget() self.model_list_widget.setColumnCount(1) @@ -37,20 +58,21 @@ class ModelsPreferencesWidget(QWidget): buttons_layout = QHBoxLayout() - self.download_button = QPushButton(_('Download')) - self.download_button.setObjectName('DownloadButton') + self.download_button = QPushButton(_("Download")) + self.download_button.setObjectName("DownloadButton") self.download_button.clicked.connect(self.on_download_button_clicked) buttons_layout.addWidget(self.download_button) - self.show_file_location_button = QPushButton(_('Show file location')) - self.show_file_location_button.setObjectName('ShowFileLocationButton') + self.show_file_location_button = QPushButton(_("Show file location")) + self.show_file_location_button.setObjectName("ShowFileLocationButton") self.show_file_location_button.clicked.connect( - self.on_show_file_location_button_clicked) + self.on_show_file_location_button_clicked + ) buttons_layout.addWidget(self.show_file_location_button) buttons_layout.addStretch(1) - self.delete_button = QPushButton(_('Delete')) - self.delete_button.setObjectName('DeleteButton') + self.delete_button = QPushButton(_("Delete")) + self.delete_button.setObjectName("DeleteButton") self.delete_button.clicked.connect(self.on_delete_button_clicked) buttons_layout.addWidget(self.delete_button) @@ -71,9 +93,10 @@ class ModelsPreferencesWidget(QWidget): @staticmethod def can_delete_model(model: TranscriptionModel): - return ((model.model_type == ModelType.WHISPER or - model.model_type == ModelType.WHISPER_CPP) and - model.get_local_model_path() is not None) + return ( + model.model_type == ModelType.WHISPER + or model.model_type == ModelType.WHISPER_CPP + ) and model.get_local_model_path() is not None def reset(self): # reset buttons @@ -85,20 +108,21 @@ class ModelsPreferencesWidget(QWidget): # reset model list self.model_list_widget.clear() downloaded_item = QTreeWidgetItem(self.model_list_widget) - downloaded_item.setText(0, _('Downloaded')) + downloaded_item.setText(0, _("Downloaded")) downloaded_item.setFlags( - downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable) + downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable + ) available_item = QTreeWidgetItem(self.model_list_widget) - available_item.setText(0, _('Available for Download')) - available_item.setFlags( - available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable) + available_item.setText(0, _("Available for Download")) + available_item.setFlags(available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable) self.model_list_widget.addTopLevelItems([downloaded_item, available_item]) self.model_list_widget.expandToDepth(2) self.model_list_widget.setHeaderHidden(True) self.model_list_widget.setAlternatingRowColors(True) for model_size in WhisperModelSize: - model = TranscriptionModel(model_type=self.model.model_type, - whisper_model_size=model_size) + model = TranscriptionModel( + model_type=self.model.model_type, whisper_model_size=model_size + ) model_path = model.get_local_model_path() parent = downloaded_item if model_path is not None else available_item item = QTreeWidgetItem(parent) @@ -115,7 +139,9 @@ class ModelsPreferencesWidget(QWidget): def on_download_button_clicked(self): self.progress_dialog = ModelDownloadProgressDialog( model_type=self.model.model_type, - modality=self.progress_dialog_modality, parent=self) + modality=self.progress_dialog_modality, + parent=self, + ) self.progress_dialog.canceled.connect(self.on_progress_dialog_canceled) self.download_button.setEnabled(False) @@ -128,8 +154,10 @@ class ModelsPreferencesWidget(QWidget): def on_delete_button_clicked(self): reply = QMessageBox.question( - self, _('Delete Model'), - _('Are you sure you want to delete the selected model?')) + self, + _("Delete Model"), + _("Are you sure you want to delete the selected model?"), + ) if reply == QMessageBox.StandardButton.Yes: self.model.delete_local_file() self.reset() @@ -147,7 +175,7 @@ class ModelsPreferencesWidget(QWidget): self.progress_dialog = None self.download_button.setEnabled(True) self.reset() - QMessageBox.warning(self, _('Error'), f'Download failed: {error}') + QMessageBox.warning(self, _("Error"), f"Download failed: {error}") def on_download_progress(self, progress: tuple): self.progress_dialog.set_value(float(progress[0]) / progress[1]) diff --git a/buzz/widgets/preferences_dialog/preferences_dialog.py b/buzz/widgets/preferences_dialog/preferences_dialog.py index fd68b7c7..d04a6cea 100644 --- a/buzz/widgets/preferences_dialog/preferences_dialog.py +++ b/buzz/widgets/preferences_dialog/preferences_dialog.py @@ -4,12 +4,15 @@ from PyQt6.QtCore import pyqtSignal from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QTabWidget, QDialogButtonBox from buzz.locale import _ -from buzz.widgets.preferences_dialog.general_preferences_widget import \ - GeneralPreferencesWidget -from buzz.widgets.preferences_dialog.models_preferences_widget import \ - ModelsPreferencesWidget -from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import \ - ShortcutsEditorPreferencesWidget +from buzz.widgets.preferences_dialog.general_preferences_widget import ( + GeneralPreferencesWidget, +) +from buzz.widgets.preferences_dialog.models_preferences_widget import ( + ModelsPreferencesWidget, +) +from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import ( + ShortcutsEditorPreferencesWidget, +) class PreferencesDialog(QDialog): @@ -17,31 +20,38 @@ class PreferencesDialog(QDialog): openai_api_key_changed = pyqtSignal(str) default_export_file_name_changed = pyqtSignal(str) - def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str, - parent: Optional[QWidget] = None) -> None: + def __init__( + self, + shortcuts: Dict[str, str], + default_export_file_name: str, + parent: Optional[QWidget] = None, + ) -> None: super().__init__(parent) - self.setWindowTitle('Preferences') + self.setWindowTitle("Preferences") layout = QVBoxLayout(self) tab_widget = QTabWidget(self) general_tab_widget = GeneralPreferencesWidget( - default_export_file_name=default_export_file_name, parent=self) + default_export_file_name=default_export_file_name, parent=self + ) general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed) general_tab_widget.default_export_file_name_changed.connect( - self.default_export_file_name_changed) - tab_widget.addTab(general_tab_widget, _('General')) + self.default_export_file_name_changed + ) + tab_widget.addTab(general_tab_widget, _("General")) models_tab_widget = ModelsPreferencesWidget(parent=self) - tab_widget.addTab(models_tab_widget, _('Models')) + tab_widget.addTab(models_tab_widget, _("Models")) shortcuts_table_widget = ShortcutsEditorPreferencesWidget(shortcuts, self) shortcuts_table_widget.shortcuts_changed.connect(self.shortcuts_changed) - tab_widget.addTab(shortcuts_table_widget, _('Shortcuts')) + tab_widget.addTab(shortcuts_table_widget, _("Shortcuts")) - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton( - QDialogButtonBox.StandardButton.Ok), self) + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self + ) button_box.accepted.connect(self.accept) button_box.rejected.connect(self.reject) diff --git a/buzz/widgets/preferences_dialog/shortcuts_editor_preferences_widget.py b/buzz/widgets/preferences_dialog/shortcuts_editor_preferences_widget.py index f0da8053..47fdf0b3 100644 --- a/buzz/widgets/preferences_dialog/shortcuts_editor_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/shortcuts_editor_preferences_widget.py @@ -18,12 +18,13 @@ class ShortcutsEditorPreferencesWidget(QWidget): self.layout = QFormLayout(self) for shortcut in Shortcut: - sequence_edit = SequenceEdit(shortcuts.get(shortcut.name, ''), self) + sequence_edit = SequenceEdit(shortcuts.get(shortcut.name, ""), self) sequence_edit.keySequenceChanged.connect( - self.get_key_sequence_changed(shortcut.name)) + self.get_key_sequence_changed(shortcut.name) + ) self.layout.addRow(shortcut.description, sequence_edit) - reset_to_defaults_button = QPushButton('Reset to Defaults', self) + reset_to_defaults_button = QPushButton("Reset to Defaults", self) reset_to_defaults_button.setDefault(False) reset_to_defaults_button.setAutoDefault(False) reset_to_defaults_button.clicked.connect(self.reset_to_defaults) @@ -41,8 +42,9 @@ class ShortcutsEditorPreferencesWidget(QWidget): self.shortcuts = Shortcut.get_default_shortcuts() for i, shortcut in enumerate(Shortcut): - sequence_edit = self.layout.itemAt(i, - QFormLayout.ItemRole.FieldRole).widget() + sequence_edit = self.layout.itemAt( + i, QFormLayout.ItemRole.FieldRole + ).widget() assert isinstance(sequence_edit, SequenceEdit) sequence_edit.setKeySequence(QKeySequence(self.shortcuts[shortcut.name])) diff --git a/buzz/widgets/sequence_edit.py b/buzz/widgets/sequence_edit.py index 59770788..6af5799b 100644 --- a/buzz/widgets/sequence_edit.py +++ b/buzz/widgets/sequence_edit.py @@ -10,8 +10,8 @@ class SequenceEdit(QKeySequenceEdit): def __init__(self, sequence: str, parent: Optional[QWidget] = None): super().__init__(sequence, parent) self.setClearButtonEnabled(True) - if platform.system() == 'Darwin': - self.setStyleSheet('QLineEdit:focus { border: 2px solid #4d90fe; }') + if platform.system() == "Darwin": + self.setStyleSheet("QLineEdit:focus { border: 2px solid #4d90fe; }") def keyPressEvent(self, event: QtGui.QKeyEvent) -> None: key = event.key() @@ -23,7 +23,12 @@ class SequenceEdit(QKeySequenceEdit): return # Ignore pressing *only* modifier keys - if key == Qt.Key.Key_Control or key == Qt.Key.Key_Shift or key == Qt.Key.Key_Alt or key == Qt.Key.Key_Meta: + if ( + key == Qt.Key.Key_Control + or key == Qt.Key.Key_Shift + or key == Qt.Key.Key_Alt + or key == Qt.Key.Key_Meta + ): return super().keyPressEvent(event) diff --git a/buzz/widgets/toolbar.py b/buzz/widgets/toolbar.py index b2bed7ea..36fde2a9 100644 --- a/buzz/widgets/toolbar.py +++ b/buzz/widgets/toolbar.py @@ -11,7 +11,7 @@ class ToolBar(QToolBar): super().__init__(parent) self.setIconSize(QSize(18, 18)) - self.setStyleSheet('QToolButton{margin: 6px 3px;}') + self.setStyleSheet("QToolButton{margin: 6px 3px;}") self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) def addAction(self, action: QtGui.QAction) -> None: @@ -23,6 +23,7 @@ class ToolBar(QToolBar): self.fix_spacing_on_mac() def fix_spacing_on_mac(self): - if platform.system() == 'Darwin': + if platform.system() == "Darwin": self.widgetForAction(self.actions()[0]).setStyleSheet( - 'QToolButton { margin-left: 9px; margin-right: 1px; }') + "QToolButton { margin-left: 9px; margin-right: 1px; }" + ) diff --git a/buzz/widgets/transcriber/advanced_settings_button.py b/buzz/widgets/transcriber/advanced_settings_button.py index 84100559..4bc92657 100644 --- a/buzz/widgets/transcriber/advanced_settings_button.py +++ b/buzz/widgets/transcriber/advanced_settings_button.py @@ -5,4 +5,4 @@ from PyQt6.QtWidgets import QPushButton, QWidget class AdvancedSettingsButton(QPushButton): def __init__(self, parent: Optional[QWidget]) -> None: - super().__init__('Advanced...', parent) + super().__init__("Advanced...", parent) diff --git a/buzz/widgets/transcriber/advanced_settings_dialog.py b/buzz/widgets/transcriber/advanced_settings_dialog.py index 172760cc..be7fe099 100644 --- a/buzz/widgets/transcriber/advanced_settings_dialog.py +++ b/buzz/widgets/transcriber/advanced_settings_dialog.py @@ -1,6 +1,11 @@ from PyQt6.QtCore import pyqtSignal -from PyQt6.QtWidgets import QDialog, QWidget, QDialogButtonBox, QFormLayout, \ - QPlainTextEdit +from PyQt6.QtWidgets import ( + QDialog, + QWidget, + QDialogButtonBox, + QFormLayout, + QPlainTextEdit, +) from buzz.widgets.transcriber.temperature_validator import TemperatureValidator from buzz.locale import _ @@ -13,39 +18,48 @@ class AdvancedSettingsDialog(QDialog): transcription_options: TranscriptionOptions transcription_options_changed = pyqtSignal(TranscriptionOptions) - def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None): + def __init__( + self, transcription_options: TranscriptionOptions, parent: QWidget | None = None + ): super().__init__(parent) self.transcription_options = transcription_options - self.setWindowTitle(_('Advanced Settings')) + self.setWindowTitle(_("Advanced Settings")) - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton( - QDialogButtonBox.StandardButton.Ok), self) + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self + ) button_box.accepted.connect(self.accept) layout = QFormLayout(self) - default_temperature_text = ', '.join( - [str(temp) for temp in transcription_options.temperature]) + default_temperature_text = ", ".join( + [str(temp) for temp in transcription_options.temperature] + ) self.temperature_line_edit = LineEdit(default_temperature_text, self) self.temperature_line_edit.setPlaceholderText( - _('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')) + _('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"') + ) self.temperature_line_edit.setMinimumWidth(170) - self.temperature_line_edit.textChanged.connect( - self.on_temperature_changed) + self.temperature_line_edit.textChanged.connect(self.on_temperature_changed) self.temperature_line_edit.setValidator(TemperatureValidator(self)) - self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER) + self.temperature_line_edit.setEnabled( + transcription_options.model.model_type == ModelType.WHISPER + ) self.initial_prompt_text_edit = QPlainTextEdit( - transcription_options.initial_prompt, self) + transcription_options.initial_prompt, self + ) self.initial_prompt_text_edit.textChanged.connect( - self.on_initial_prompt_changed) + self.on_initial_prompt_changed + ) self.initial_prompt_text_edit.setEnabled( - transcription_options.model.model_type == ModelType.WHISPER) + transcription_options.model.model_type == ModelType.WHISPER + ) - layout.addRow(_('Temperature:'), self.temperature_line_edit) - layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit) + layout.addRow(_("Temperature:"), self.temperature_line_edit) + layout.addRow(_("Initial Prompt:"), self.initial_prompt_text_edit) layout.addWidget(button_box) self.setLayout(layout) @@ -53,12 +67,14 @@ class AdvancedSettingsDialog(QDialog): def on_temperature_changed(self, text: str): try: - temperatures = [float(temp.strip()) for temp in text.split(',')] + temperatures = [float(temp.strip()) for temp in text.split(",")] self.transcription_options.temperature = tuple(temperatures) self.transcription_options_changed.emit(self.transcription_options) except ValueError: pass def on_initial_prompt_changed(self): - self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText() + self.transcription_options.initial_prompt = ( + self.initial_prompt_text_edit.toPlainText() + ) self.transcription_options_changed.emit(self.transcription_options) diff --git a/buzz/widgets/transcriber/file_transcriber_widget.py b/buzz/widgets/transcriber/file_transcriber_widget.py index ebcd6629..2627530d 100644 --- a/buzz/widgets/transcriber/file_transcriber_widget.py +++ b/buzz/widgets/transcriber/file_transcriber_widget.py @@ -2,8 +2,14 @@ from typing import Optional, List, Tuple from PyQt6 import QtGui from PyQt6.QtCore import pyqtSignal, Qt, QThreadPool -from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout, \ - QPushButton +from PyQt6.QtWidgets import ( + QWidget, + QVBoxLayout, + QCheckBox, + QFormLayout, + QHBoxLayout, + QPushButton, +) from buzz.dialogs import show_model_download_error_dialog from buzz.locale import _ @@ -11,11 +17,17 @@ from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType from buzz.paths import file_paths_as_title from buzz.settings.settings import Settings from buzz.store.keyring_store import KeyringStore -from buzz.transcriber import FileTranscriptionOptions, TranscriptionOptions, Task, \ - DEFAULT_WHISPER_TEMPERATURE, OutputFormat +from buzz.transcriber import ( + FileTranscriptionOptions, + TranscriptionOptions, + Task, + DEFAULT_WHISPER_TEMPERATURE, + OutputFormat, +) from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog -from buzz.widgets.transcriber.transcription_options_group_box import \ - TranscriptionOptionsGroupBox +from buzz.widgets.transcriber.transcription_options_group_box import ( + TranscriptionOptionsGroupBox, +) class FileTranscriberWidget(QWidget): @@ -29,74 +41,100 @@ class FileTranscriberWidget(QWidget): openai_access_token_changed = pyqtSignal(str) settings = Settings() - def __init__(self, file_paths: List[str], - default_output_file_name: str, - parent: Optional[QWidget] = None, - flags: Qt.WindowType = Qt.WindowType.Widget) -> None: + def __init__( + self, + file_paths: List[str], + default_output_file_name: str, + parent: Optional[QWidget] = None, + flags: Qt.WindowType = Qt.WindowType.Widget, + ) -> None: super().__init__(parent, flags) self.setWindowTitle(file_paths_as_title(file_paths)) openai_access_token = KeyringStore().get_password( - KeyringStore.Key.OPENAI_API_KEY) + KeyringStore.Key.OPENAI_API_KEY + ) self.file_paths = file_paths default_language = self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='') + key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value="" + ) self.transcription_options = TranscriptionOptions( openai_access_token=openai_access_token, - model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL, - default_value=TranscriptionModel()), - task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK, - default_value=Task.TRANSCRIBE), - language=default_language if default_language != '' else None, + model=self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_MODEL, + default_value=TranscriptionModel(), + ), + task=self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE + ), + language=default_language if default_language != "" else None, initial_prompt=self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''), + key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value="" + ), temperature=self.settings.value( key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, - default_value=DEFAULT_WHISPER_TEMPERATURE), + default_value=DEFAULT_WHISPER_TEMPERATURE, + ), word_level_timings=self.settings.value( key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - default_value=False)) + default_value=False, + ), + ) default_export_format_states: List[str] = self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, - default_value=[]) + key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, default_value=[] + ) self.file_transcription_options = FileTranscriptionOptions( file_paths=self.file_paths, - output_formats=set([OutputFormat(output_format) for output_format in - default_export_format_states]), - default_output_file_name=default_output_file_name) + output_formats=set( + [ + OutputFormat(output_format) + for output_format in default_export_format_states + ] + ), + default_output_file_name=default_output_file_name, + ) layout = QVBoxLayout(self) transcription_options_group_box = TranscriptionOptionsGroupBox( - default_transcription_options=self.transcription_options, parent=self) + default_transcription_options=self.transcription_options, parent=self + ) transcription_options_group_box.transcription_options_changed.connect( - self.on_transcription_options_changed) + self.on_transcription_options_changed + ) - self.word_level_timings_checkbox = QCheckBox(_('Word-level timings')) + self.word_level_timings_checkbox = QCheckBox(_("Word-level timings")) self.word_level_timings_checkbox.setChecked( - self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - default_value=False)) + self.settings.value( + key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, + default_value=False, + ) + ) self.word_level_timings_checkbox.stateChanged.connect( - self.on_word_level_timings_changed) + self.on_word_level_timings_changed + ) file_transcription_layout = QFormLayout() - file_transcription_layout.addRow('', self.word_level_timings_checkbox) + file_transcription_layout.addRow("", self.word_level_timings_checkbox) export_format_layout = QHBoxLayout() for output_format in OutputFormat: - export_format_checkbox = QCheckBox(f'{output_format.value.upper()}', - parent=self) + export_format_checkbox = QCheckBox( + f"{output_format.value.upper()}", parent=self + ) export_format_checkbox.setChecked( - output_format in self.file_transcription_options.output_formats) + output_format in self.file_transcription_options.output_formats + ) export_format_checkbox.stateChanged.connect( - self.get_on_checkbox_state_changed_callback(output_format)) + self.get_on_checkbox_state_changed_callback(output_format) + ) export_format_layout.addWidget(export_format_checkbox) - file_transcription_layout.addRow('Export:', export_format_layout) + file_transcription_layout.addRow("Export:", export_format_layout) - self.run_button = QPushButton(_('Run'), self) + self.run_button = QPushButton(_("Run"), self) self.run_button.setDefault(True) self.run_button.clicked.connect(self.on_click_run) @@ -116,15 +154,19 @@ class FileTranscriberWidget(QWidget): return on_checkbox_state_changed - def on_transcription_options_changed(self, - transcription_options: TranscriptionOptions): + def on_transcription_options_changed( + self, transcription_options: TranscriptionOptions + ): self.transcription_options = transcription_options self.word_level_timings_checkbox.setDisabled( - self.transcription_options.model.model_type == ModelType.HUGGING_FACE or - self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API) - if self.transcription_options.openai_access_token != '': + self.transcription_options.model.model_type == ModelType.HUGGING_FACE + or self.transcription_options.model.model_type + == ModelType.OPEN_AI_WHISPER_API + ) + if self.transcription_options.openai_access_token != "": self.openai_access_token_changed.emit( - self.transcription_options.openai_access_token) + self.transcription_options.openai_access_token + ) def on_click_run(self): self.run_button.setDisabled(True) @@ -143,8 +185,9 @@ class FileTranscriberWidget(QWidget): def on_model_loaded(self, model_path: str): self.reset_transcriber_controls() - self.triggered.emit((self.transcription_options, - self.file_transcription_options, model_path)) + self.triggered.emit( + (self.transcription_options, self.file_transcription_options, model_path) + ) self.close() def on_download_model_progress(self, progress: Tuple[float, float]): @@ -152,13 +195,16 @@ class FileTranscriberWidget(QWidget): if self.model_download_progress_dialog is None: self.model_download_progress_dialog = ModelDownloadProgressDialog( - model_type=self.transcription_options.model.model_type, parent=self) + model_type=self.transcription_options.model.model_type, parent=self + ) self.model_download_progress_dialog.canceled.connect( - self.on_cancel_model_progress_dialog) + self.on_cancel_model_progress_dialog + ) if self.model_download_progress_dialog is not None: self.model_download_progress_dialog.set_value( - fraction_completed=current_size / total_size) + fraction_completed=current_size / total_size + ) def on_download_model_error(self, error: str): self.reset_model_download() @@ -179,26 +225,41 @@ class FileTranscriberWidget(QWidget): self.model_download_progress_dialog = None def on_word_level_timings_changed(self, value: int): - self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value + self.transcription_options.word_level_timings = ( + value == Qt.CheckState.Checked.value + ) def closeEvent(self, event: QtGui.QCloseEvent) -> None: if self.model_loader is not None: self.model_loader.cancel() - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, - self.transcription_options.language) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, - self.transcription_options.task) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, - self.transcription_options.temperature) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, - self.transcription_options.initial_prompt) - self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL, - self.transcription_options.model) - self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - value=self.transcription_options.word_level_timings) - self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, - value=[export_format.value for export_format in - self.file_transcription_options.output_formats]) + self.settings.set_value( + Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language + ) + self.settings.set_value( + Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task + ) + self.settings.set_value( + Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, + self.transcription_options.temperature, + ) + self.settings.set_value( + Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, + self.transcription_options.initial_prompt, + ) + self.settings.set_value( + Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model + ) + self.settings.set_value( + key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, + value=self.transcription_options.word_level_timings, + ) + self.settings.set_value( + key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, + value=[ + export_format.value + for export_format in self.file_transcription_options.output_formats + ], + ) super().closeEvent(event) diff --git a/buzz/widgets/transcriber/hugging_face_search_line_edit.py b/buzz/widgets/transcriber/hugging_face_search_line_edit.py index 4c98a516..017c579c 100644 --- a/buzz/widgets/transcriber/hugging_face_search_line_edit.py +++ b/buzz/widgets/transcriber/hugging_face_search_line_edit.py @@ -2,8 +2,17 @@ import json import logging from typing import Optional -from PyQt6.QtCore import pyqtSignal, QTimer, Qt, QMetaObject, QUrl, QUrlQuery, QPoint, \ - QObject, QEvent +from PyQt6.QtCore import ( + pyqtSignal, + QTimer, + Qt, + QMetaObject, + QUrl, + QUrlQuery, + QPoint, + QObject, + QEvent, +) from PyQt6.QtGui import QKeyEvent from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem @@ -16,12 +25,15 @@ class HuggingFaceSearchLineEdit(LineEdit): model_selected = pyqtSignal(str) popup: QListWidget - def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None, - parent: Optional[QWidget] = None): - super().__init__('', parent) + def __init__( + self, + network_access_manager: Optional[QNetworkAccessManager] = None, + parent: Optional[QWidget] = None, + ): + super().__init__("", parent) self.setMinimumWidth(150) - self.setPlaceholderText('openai/whisper-tiny') + self.setPlaceholderText("openai/whisper-tiny") self.timer = QTimer(self) self.timer.setSingleShot(True) @@ -56,7 +68,7 @@ class HuggingFaceSearchLineEdit(LineEdit): item = self.popup.currentItem() self.setText(item.text()) - QMetaObject.invokeMethod(self, 'returnPressed') + QMetaObject.invokeMethod(self, "returnPressed") self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole)) def fetch_models(self): @@ -79,7 +91,9 @@ class HuggingFaceSearchLineEdit(LineEdit): def on_request_response(self, network_reply: QNetworkReply): if network_reply.error() != QNetworkReply.NetworkError.NoError: - logging.debug('Error fetching Hugging Face models: %s', network_reply.error()) + logging.debug( + "Error fetching Hugging Face models: %s", network_reply.error() + ) return models = json.loads(network_reply.readAll().data()) @@ -88,7 +102,7 @@ class HuggingFaceSearchLineEdit(LineEdit): self.popup.clear() for model in models: - model_id = model.get('id') + model_id = model.get("id") item = QListWidgetItem(self.popup) item.setText(model_id) @@ -96,14 +110,16 @@ class HuggingFaceSearchLineEdit(LineEdit): self.popup.setCurrentItem(self.popup.item(0)) self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20) - self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll + self.popup.setFixedHeight( + self.popup.sizeHintForRow(0) * min(len(models), 8) + ) # show max 8 models, then scroll self.popup.setUpdatesEnabled(True) self.popup.move(self.mapToGlobal(QPoint(0, self.height()))) self.popup.setFocus() self.popup.show() def eventFilter(self, target: QObject, event: QEvent): - if hasattr(self, 'popup') is False or target != self.popup: + if hasattr(self, "popup") is False or target != self.popup: return False if event.type() == QEvent.Type.MouseButtonPress: @@ -123,8 +139,14 @@ class HuggingFaceSearchLineEdit(LineEdit): self.popup.hide() return True - if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp, - Qt.Key.Key_PageDown]: + if key in [ + Qt.Key.Key_Up, + Qt.Key.Key_Down, + Qt.Key.Key_Home, + Qt.Key.Key_End, + Qt.Key.Key_PageUp, + Qt.Key.Key_PageDown, + ]: return False self.setFocus() diff --git a/buzz/widgets/transcriber/languages_combo_box.py b/buzz/widgets/transcriber/languages_combo_box.py index 786a482d..3355af63 100644 --- a/buzz/widgets/transcriber/languages_combo_box.py +++ b/buzz/widgets/transcriber/languages_combo_box.py @@ -9,20 +9,25 @@ from buzz.transcriber import LANGUAGES class LanguagesComboBox(QComboBox): """LanguagesComboBox displays a list of languages available to use with Whisper""" + # language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language" languageChanged = pyqtSignal(str) - def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None: + def __init__( + self, default_language: Optional[str], parent: Optional[QWidget] = None + ) -> None: super().__init__(parent) whisper_languages = sorted( - [(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1]) - self.languages = [('', _('Detect Language'))] + whisper_languages + [(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], + key=lambda lang: lang[1], + ) + self.languages = [("", _("Detect Language"))] + whisper_languages self.addItems([lang[1] for lang in self.languages]) self.currentIndexChanged.connect(self.on_index_changed) - default_language_key = default_language if default_language != '' else None + default_language_key = default_language if default_language != "" else None for i, lang in enumerate(self.languages): if lang[0] == default_language_key: self.setCurrentIndex(i) diff --git a/buzz/widgets/transcriber/tasks_combo_box.py b/buzz/widgets/transcriber/tasks_combo_box.py index 2b2c02cc..f889a2be 100644 --- a/buzz/widgets/transcriber/tasks_combo_box.py +++ b/buzz/widgets/transcriber/tasks_combo_box.py @@ -8,6 +8,7 @@ from buzz.transcriber import Task class TasksComboBox(QComboBox): """TasksComboBox displays a list of tasks available to use with Whisper""" + taskChanged = pyqtSignal(Task) def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None: diff --git a/buzz/widgets/transcriber/temperature_validator.py b/buzz/widgets/transcriber/temperature_validator.py index 29986a62..3fcc97ec 100644 --- a/buzz/widgets/transcriber/temperature_validator.py +++ b/buzz/widgets/transcriber/temperature_validator.py @@ -8,10 +8,12 @@ class TemperatureValidator(QValidator): def __init__(self, parent: Optional[QObject] = ...) -> None: super().__init__(parent) - def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]: + def validate( + self, text: str, cursor_position: int + ) -> Tuple["QValidator.State", str, int]: try: - temp_strings = [temp.strip() for temp in text.split(',')] - if temp_strings[-1] == '': + temp_strings = [temp.strip() for temp in text.split(",")] + if temp_strings[-1] == "": return QValidator.State.Intermediate, text, cursor_position _ = [float(temp) for temp in temp_strings] return QValidator.State.Acceptable, text, cursor_position diff --git a/buzz/widgets/transcriber/transcription_options_group_box.py b/buzz/widgets/transcriber/transcription_options_group_box.py index 3ec1f676..8eb592e6 100644 --- a/buzz/widgets/transcriber/transcription_options_group_box.py +++ b/buzz/widgets/transcriber/transcription_options_group_box.py @@ -10,8 +10,9 @@ from buzz.widgets.model_type_combo_box import ModelTypeComboBox from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit from buzz.widgets.transcriber.advanced_settings_button import AdvancedSettingsButton from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog -from buzz.widgets.transcriber.hugging_face_search_line_edit import \ - HuggingFaceSearchLineEdit +from buzz.widgets.transcriber.hugging_face_search_line_edit import ( + HuggingFaceSearchLineEdit, +) from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox @@ -21,64 +22,70 @@ class TranscriptionOptionsGroupBox(QGroupBox): transcription_options_changed = pyqtSignal(TranscriptionOptions) def __init__( - self, - default_transcription_options: TranscriptionOptions = TranscriptionOptions(), - model_types: Optional[List[ModelType]] = None, - parent: Optional[QWidget] = None): - super().__init__(title='', parent=parent) + self, + default_transcription_options: TranscriptionOptions = TranscriptionOptions(), + model_types: Optional[List[ModelType]] = None, + parent: Optional[QWidget] = None, + ): + super().__init__(title="", parent=parent) self.transcription_options = default_transcription_options self.form_layout = QFormLayout(self) self.tasks_combo_box = TasksComboBox( - default_task=self.transcription_options.task, - parent=self) + default_task=self.transcription_options.task, parent=self + ) self.tasks_combo_box.taskChanged.connect(self.on_task_changed) self.languages_combo_box = LanguagesComboBox( - default_language=self.transcription_options.language, - parent=self) - self.languages_combo_box.languageChanged.connect( - self.on_language_changed) + default_language=self.transcription_options.language, parent=self + ) + self.languages_combo_box.languageChanged.connect(self.on_language_changed) self.advanced_settings_button = AdvancedSettingsButton(self) - self.advanced_settings_button.clicked.connect( - self.open_advanced_settings) + self.advanced_settings_button.clicked.connect(self.open_advanced_settings) self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit() self.hugging_face_search_line_edit.model_selected.connect( - self.on_hugging_face_model_changed) + self.on_hugging_face_model_changed + ) - self.model_type_combo_box = ModelTypeComboBox(model_types=model_types, - default_model=default_transcription_options.model.model_type, - parent=self) + self.model_type_combo_box = ModelTypeComboBox( + model_types=model_types, + default_model=default_transcription_options.model.model_type, + parent=self, + ) self.model_type_combo_box.changed.connect(self.on_model_type_changed) self.whisper_model_size_combo_box = QComboBox(self) self.whisper_model_size_combo_box.addItems( - [size.value.title() for size in WhisperModelSize]) + [size.value.title() for size in WhisperModelSize] + ) if default_transcription_options.model.whisper_model_size is not None: self.whisper_model_size_combo_box.setCurrentText( - default_transcription_options.model.whisper_model_size.value.title()) + default_transcription_options.model.whisper_model_size.value.title() + ) self.whisper_model_size_combo_box.currentTextChanged.connect( - self.on_whisper_model_size_changed) + self.on_whisper_model_size_changed + ) self.openai_access_token_edit = OpenAIAPIKeyLineEdit( - key=default_transcription_options.openai_access_token, - parent=self) + key=default_transcription_options.openai_access_token, parent=self + ) self.openai_access_token_edit.key_changed.connect( - self.on_openai_access_token_edit_changed) + self.on_openai_access_token_edit_changed + ) - self.form_layout.addRow(_('Model:'), self.model_type_combo_box) - self.form_layout.addRow('', self.whisper_model_size_combo_box) - self.form_layout.addRow('', self.hugging_face_search_line_edit) - self.form_layout.addRow('Access Token:', self.openai_access_token_edit) - self.form_layout.addRow(_('Task:'), self.tasks_combo_box) - self.form_layout.addRow(_('Language:'), self.languages_combo_box) + self.form_layout.addRow(_("Model:"), self.model_type_combo_box) + self.form_layout.addRow("", self.whisper_model_size_combo_box) + self.form_layout.addRow("", self.hugging_face_search_line_edit) + self.form_layout.addRow("Access Token:", self.openai_access_token_edit) + self.form_layout.addRow(_("Task:"), self.tasks_combo_box) + self.form_layout.addRow(_("Language:"), self.languages_combo_box) self.reset_visible_rows() - self.form_layout.addRow('', self.advanced_settings_button) + self.form_layout.addRow("", self.advanced_settings_button) self.setLayout(self.form_layout) @@ -104,26 +111,33 @@ class TranscriptionOptionsGroupBox(QGroupBox): def open_advanced_settings(self): dialog = AdvancedSettingsDialog( - transcription_options=self.transcription_options, parent=self) + transcription_options=self.transcription_options, parent=self + ) dialog.transcription_options_changed.connect( - self.on_transcription_options_changed) + self.on_transcription_options_changed + ) dialog.exec() - def on_transcription_options_changed(self, - transcription_options: TranscriptionOptions): + def on_transcription_options_changed( + self, transcription_options: TranscriptionOptions + ): self.transcription_options = transcription_options self.transcription_options_changed.emit(transcription_options) def reset_visible_rows(self): model_type = self.transcription_options.model.model_type - self.form_layout.setRowVisible(self.hugging_face_search_line_edit, - model_type == ModelType.HUGGING_FACE) - self.form_layout.setRowVisible(self.whisper_model_size_combo_box, - (model_type == ModelType.WHISPER) or ( - model_type == ModelType.WHISPER_CPP) or ( - model_type == ModelType.FASTER_WHISPER)) - self.form_layout.setRowVisible(self.openai_access_token_edit, - model_type == ModelType.OPEN_AI_WHISPER_API) + self.form_layout.setRowVisible( + self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE + ) + self.form_layout.setRowVisible( + self.whisper_model_size_combo_box, + (model_type == ModelType.WHISPER) + or (model_type == ModelType.WHISPER_CPP) + or (model_type == ModelType.FASTER_WHISPER), + ) + self.form_layout.setRowVisible( + self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API + ) def on_model_type_changed(self, model_type: ModelType): self.transcription_options.model.model_type = model_type diff --git a/buzz/widgets/transcription_segments_editor_widget.py b/buzz/widgets/transcription_segments_editor_widget.py index 0796231d..9bdbebe0 100644 --- a/buzz/widgets/transcription_segments_editor_widget.py +++ b/buzz/widgets/transcription_segments_editor_widget.py @@ -27,9 +27,10 @@ class TranscriptionSegmentsEditorWidget(QTableWidget): self.setColumnCount(3) self.verticalHeader().hide() - self.setHorizontalHeaderLabels([_('Start'), _('End'), _('Text')]) - self.horizontalHeader().setSectionResizeMode(2, - QHeaderView.ResizeMode.ResizeToContents) + self.setHorizontalHeaderLabels([_("Start"), _("End"), _("Text")]) + self.horizontalHeader().setSectionResizeMode( + 2, QHeaderView.ResizeMode.ResizeToContents + ) self.setSelectionMode(QTableWidget.SelectionMode.SingleSelection) for segment in segments: @@ -38,12 +39,18 @@ class TranscriptionSegmentsEditorWidget(QTableWidget): start_item = QTableWidgetItem(to_timestamp(segment.start)) start_item.setFlags( - start_item.flags() & ~Qt.ItemFlag.ItemIsEditable & ~Qt.ItemFlag.ItemIsSelectable) + start_item.flags() + & ~Qt.ItemFlag.ItemIsEditable + & ~Qt.ItemFlag.ItemIsSelectable + ) self.setItem(row_index, self.Column.START.value, start_item) end_item = QTableWidgetItem(to_timestamp(segment.end)) end_item.setFlags( - end_item.flags() & ~Qt.ItemFlag.ItemIsEditable & ~Qt.ItemFlag.ItemIsSelectable) + end_item.flags() + & ~Qt.ItemFlag.ItemIsEditable + & ~Qt.ItemFlag.ItemIsSelectable + ) self.setItem(row_index, self.Column.END.value, end_item) text_item = QTableWidgetItem(segment.text) @@ -61,5 +68,4 @@ class TranscriptionSegmentsEditorWidget(QTableWidget): def on_item_selection_changed(self): ranges = self.selectedRanges() - self.segment_index_selected.emit( - ranges[0].topRow() if len(ranges) > 0 else -1) + self.segment_index_selected.emit(ranges[0].topRow() if len(ranges) > 0 else -1) diff --git a/buzz/widgets/transcription_tasks_table_widget.py b/buzz/widgets/transcription_tasks_table_widget.py index 7edaf40a..cd79fcd3 100644 --- a/buzz/widgets/transcription_tasks_table_widget.py +++ b/buzz/widgets/transcription_tasks_table_widget.py @@ -30,13 +30,12 @@ class TranscriptionTasksTableWidget(QTableWidget): self.setColumnHidden(0, True) self.verticalHeader().hide() - self.setHorizontalHeaderLabels([_('ID'), _('File Name'), _('Status')]) + self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")]) self.setColumnWidth(self.Column.FILE_NAME.value, 250) self.setColumnWidth(self.Column.STATUS.value, 180) self.horizontalHeader().setMinimumSectionSize(180) - self.setSelectionBehavior( - QAbstractItemView.SelectionBehavior.SelectRows) + self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) def upsert_task(self, task: FileTranscriptionTask): task_row_index = self.task_row_index(task.id) @@ -45,21 +44,19 @@ class TranscriptionTasksTableWidget(QTableWidget): row_index = self.rowCount() - 1 task_id_widget_item = QTableWidgetItem(str(task.id)) - self.setItem(row_index, self.Column.TASK_ID.value, - task_id_widget_item) + self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item) - file_name_widget_item = QTableWidgetItem( - os.path.basename(task.file_path)) + file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path)) file_name_widget_item.setFlags( - file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable) - self.setItem(row_index, self.Column.FILE_NAME.value, - file_name_widget_item) + file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable + ) + self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item) status_widget_item = QTableWidgetItem(self.get_status_text(task)) status_widget_item.setFlags( - status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable) - self.setItem(row_index, self.Column.STATUS.value, - status_widget_item) + status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable + ) + self.setItem(row_index, self.Column.STATUS.value, status_widget_item) else: status_widget = self.item(task_row_index, self.Column.STATUS.value) status_widget.setText(self.get_status_text(task)) @@ -67,31 +64,30 @@ class TranscriptionTasksTableWidget(QTableWidget): @staticmethod def format_timedelta(delta: datetime.timedelta): mm, ss = divmod(delta.seconds, 60) - result = f'{ss}s' + result = f"{ss}s" if mm == 0: return result hh, mm = divmod(mm, 60) - result = f'{mm}m {result}' + result = f"{mm}m {result}" if hh == 0: return result - return f'{hh}h {result}' + return f"{hh}h {result}" @staticmethod def get_status_text(task: FileTranscriptionTask): if task.status == FileTranscriptionTask.Status.IN_PROGRESS: - return ( - f'{_("In Progress")} ({task.fraction_completed :.0%})') + return f'{_("In Progress")} ({task.fraction_completed :.0%})' elif task.status == FileTranscriptionTask.Status.COMPLETED: - status = _('Completed') + status = _("Completed") if task.started_at is not None and task.completed_at is not None: status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})" return status elif task.status == FileTranscriptionTask.Status.FAILED: return f'{_("Failed")} ({task.error})' elif task.status == FileTranscriptionTask.Status.CANCELED: - return _('Canceled') + return _("Canceled") elif task.status == FileTranscriptionTask.Status.QUEUED: - return _('Queued') + return _("Queued") def clear_task(self, task_id: int): task_row_index = self.task_row_index(task_id) @@ -99,15 +95,20 @@ class TranscriptionTasksTableWidget(QTableWidget): self.removeRow(task_row_index) def task_row_index(self, task_id: int) -> int | None: - table_items_matching_task_id = [item for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly) if - item.column() == self.Column.TASK_ID.value] + table_items_matching_task_id = [ + item + for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly) + if item.column() == self.Column.TASK_ID.value + ] if len(table_items_matching_task_id) == 0: return None return table_items_matching_task_id[0].row() @staticmethod def find_task_id(index: QModelIndex): - sibling_index = index.siblingAtColumn(TranscriptionTasksTableWidget.Column.TASK_ID.value).data() + sibling_index = index.siblingAtColumn( + TranscriptionTasksTableWidget.Column.TASK_ID.value + ).data() return int(sibling_index) if sibling_index is not None else None def keyPressEvent(self, event: QtGui.QKeyEvent) -> None: diff --git a/buzz/widgets/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer_widget.py index 41cae3ac..4c2008bc 100644 --- a/buzz/widgets/transcription_viewer_widget.py +++ b/buzz/widgets/transcription_viewer_widget.py @@ -3,20 +3,32 @@ from typing import List, Optional from PyQt6.QtCore import Qt, pyqtSignal from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence, QAction -from PyQt6.QtWidgets import QWidget, QHBoxLayout, QMenu, QPushButton, QVBoxLayout, \ - QFileDialog +from PyQt6.QtWidgets import ( + QWidget, + QHBoxLayout, + QMenu, + QPushButton, + QVBoxLayout, + QFileDialog, +) from buzz.action import Action from buzz.assets import get_asset_path from buzz.locale import _ from buzz.paths import file_path_as_title -from buzz.transcriber import FileTranscriptionTask, Segment, OutputFormat, \ - get_default_output_file_path, write_output +from buzz.transcriber import ( + FileTranscriptionTask, + Segment, + OutputFormat, + get_default_output_file_path, + write_output, +) from buzz.widgets.audio_player import AudioPlayer from buzz.widgets.icon import Icon from buzz.widgets.toolbar import ToolBar -from buzz.widgets.transcription_segments_editor_widget import \ - TranscriptionSegmentsEditorWidget +from buzz.widgets.transcription_segments_editor_widget import ( + TranscriptionSegmentsEditorWidget, +) class TranscriptionViewerWidget(QWidget): @@ -24,9 +36,14 @@ class TranscriptionViewerWidget(QWidget): task_changed = pyqtSignal() class ChangeSegmentTextCommand(QUndoCommand): - def __init__(self, table_widget: TranscriptionSegmentsEditorWidget, - segments: List[Segment], - segment_index: int, segment_text: str, task_changed: pyqtSignal): + def __init__( + self, + table_widget: TranscriptionSegmentsEditorWidget, + segments: List[Segment], + segment_index: int, + segment_text: str, + task_changed: pyqtSignal, + ): super().__init__() self.table_widget = table_widget @@ -52,10 +69,11 @@ class TranscriptionViewerWidget(QWidget): self.task_changed.emit() def __init__( - self, transcription_task: FileTranscriptionTask, - open_transcription_output=True, - parent: Optional['QWidget'] = None, - flags: Qt.WindowType = Qt.WindowType.Widget, + self, + transcription_task: FileTranscriptionTask, + open_transcription_output=True, + parent: Optional["QWidget"] = None, + flags: Qt.WindowType = Qt.WindowType.Widget, ) -> None: super().__init__(parent, flags) self.transcription_task = transcription_task @@ -71,20 +89,23 @@ class TranscriptionViewerWidget(QWidget): undo_action = self.undo_stack.createUndoAction(self, _("Undo")) undo_action.setShortcuts(QKeySequence.StandardKey.Undo) undo_action.setIcon( - Icon(get_asset_path('assets/undo_FILL0_wght700_GRAD0_opsz48.svg'), self)) + Icon(get_asset_path("assets/undo_FILL0_wght700_GRAD0_opsz48.svg"), self) + ) undo_action.setToolTip(Action.get_tooltip(undo_action)) redo_action = self.undo_stack.createRedoAction(self, _("Redo")) redo_action.setShortcuts(QKeySequence.StandardKey.Redo) redo_action.setIcon( - Icon(get_asset_path('assets/redo_FILL0_wght700_GRAD0_opsz48.svg'), self)) + Icon(get_asset_path("assets/redo_FILL0_wght700_GRAD0_opsz48.svg"), self) + ) redo_action.setToolTip(Action.get_tooltip(redo_action)) toolbar = ToolBar() toolbar.addActions([undo_action, redo_action]) self.table_widget = TranscriptionSegmentsEditorWidget( - segments=transcription_task.segments, parent=self) + segments=transcription_task.segments, parent=self + ) self.table_widget.segment_text_changed.connect(self.on_segment_text_changed) self.table_widget.segment_index_selected.connect(self.on_segment_index_selected) @@ -96,14 +117,16 @@ class TranscriptionViewerWidget(QWidget): buttons_layout.addStretch() export_button_menu = QMenu() - actions = [QAction(text=output_format.value.upper(), parent=self) - for output_format in OutputFormat] + actions = [ + QAction(text=output_format.value.upper(), parent=self) + for output_format in OutputFormat + ] export_button_menu.addActions(actions) export_button_menu.triggered.connect(self.on_menu_triggered) export_button = QPushButton(self) - export_button.setText(_('Export')) + export_button.setText(_("Export")) export_button.setMenu(export_button_menu) buttons_layout.addWidget(export_button) @@ -120,11 +143,14 @@ class TranscriptionViewerWidget(QWidget): def on_segment_text_changed(self, event: tuple): segment_index, segment_text = event self.undo_stack.push( - self.ChangeSegmentTextCommand(table_widget=self.table_widget, - segments=self.transcription_task.segments, - segment_index=segment_index, - segment_text=segment_text, - task_changed=self.task_changed)) + self.ChangeSegmentTextCommand( + table_widget=self.table_widget, + segments=self.transcription_task.segments, + segment_index=segment_index, + segment_text=segment_text, + task_changed=self.task_changed, + ) + ) def on_segment_index_selected(self, index: int): selected_segment = self.transcription_task.segments[index] @@ -134,15 +160,22 @@ class TranscriptionViewerWidget(QWidget): def on_menu_triggered(self, action: QAction): output_format = OutputFormat[action.text()] - default_path = get_default_output_file_path(task=self.transcription_task, - output_format=output_format) + default_path = get_default_output_file_path( + task=self.transcription_task, output_format=output_format + ) - (output_file_path, nil) = QFileDialog.getSaveFileName(self, _('Save File'), - default_path, - _('Text files') + f' (*.{output_format.value})') + (output_file_path, nil) = QFileDialog.getSaveFileName( + self, + _("Save File"), + default_path, + _("Text files") + f" (*.{output_format.value})", + ) - if output_file_path == '': + if output_file_path == "": return - write_output(path=output_file_path, segments=self.transcription_task.segments, - output_format=output_format) + write_output( + path=output_file_path, + segments=self.transcription_task.segments, + output_format=output_format, + ) diff --git a/poetry.lock b/poetry.lock index dd7f0fb6..fc4b74da 100644 --- a/poetry.lock +++ b/poetry.lock @@ -266,6 +266,54 @@ files = [ {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, ] +[[package]] +name = "black" +version = "23.7.0" +description = "The uncompromising code formatter." +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, + {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, + {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, + {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, + {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, + {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, + {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, + {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, + {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, + {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, + {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, +] + +[package.dependencies] +aiohttp = {version = ">=3.7.4", optional = true, markers = "extra == \"d\""} +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "certifi" version = "2023.5.7" @@ -452,6 +500,21 @@ files = [ {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, ] +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "cmake" version = "3.26.4" @@ -1534,6 +1597,18 @@ files = [ {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] +[[package]] +name = "pathspec" +version = "0.11.2" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, + {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, +] + [[package]] name = "pefile" version = "2023.2.7" @@ -2669,4 +2744,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.9.13,<3.11" -content-hash = "ceb6ce6c7083882f1499bd36f5e98f6aa1e0a872d8268ccbda91d67ee81fdd1e" +content-hash = "fe7fae59602bd0ecdceafbfe274f6f36f0cb489b67bfc7d4bfae4998dbbe672a" diff --git a/pyproject.toml b/pyproject.toml index 4449ec61..ad02d5dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ pytest-xvfb = "^2.0.0" pylint = "^2.15.5" pre-commit = "^2.20.0" pytest-benchmark = "^4.0.0" +black = {extras = ["d"], version = "^23.7.0"} [tool.poetry.group.build.dependencies] ctypesgen = "^1.1.1" diff --git a/setup.py b/setup.py index f9306276..f4663837 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,43 @@ # -*- coding: utf-8 -*- from setuptools import setup -packages = \ -['buzz', 'buzz.settings', 'buzz.store', 'buzz.widgets'] +packages = ["buzz", "buzz.settings", "buzz.store", "buzz.widgets"] -package_data = \ -{'': ['*']} +package_data = {"": ["*"]} -install_requires = \ -['PyQt6==6.4.0', - 'appdirs>=1.4.4,<2.0.0', - 'faster-whisper>=0.4.1,<0.5.0', - 'ffmpeg-python>=0.2.0,<0.3.0', - 'humanize>=4.4.0,<5.0.0', - 'keyring>=23.13.1,<24.0.0', - 'openai-whisper==v20230124', - 'openai>=0.27.1,<0.28.0', - 'platformdirs==3.5.3', - 'sounddevice>=0.4.5,<0.5.0', - 'stable-ts==1.0.2', - 'torch==1.12.1', - 'transformers>=4.24.0,<4.25.0'] +install_requires = [ + "PyQt6==6.4.0", + "appdirs>=1.4.4,<2.0.0", + "faster-whisper>=0.4.1,<0.5.0", + "ffmpeg-python>=0.2.0,<0.3.0", + "humanize>=4.4.0,<5.0.0", + "keyring>=23.13.1,<24.0.0", + "openai-whisper==v20230124", + "openai>=0.27.1,<0.28.0", + "platformdirs==3.5.3", + "sounddevice>=0.4.5,<0.5.0", + "stable-ts==1.0.2", + "torch==1.12.1", + "transformers>=4.24.0,<4.25.0", +] setup_kwargs = { - 'name': 'buzz', - 'version': '0.8.3', - 'description': '', - 'long_description': '# Buzz\n\nTranscribe and translate audio offline on your personal computer. Powered by\nOpenAI\'s [Whisper](https://github.com/openai/whisper).\n\n![MIT License](https://img.shields.io/badge/license-MIT-green)\n[![CI](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml/badge.svg)](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml)\n[![codecov](https://codecov.io/github/chidiwilliams/buzz/branch/main/graph/badge.svg?token=YJSB8S2VEP)](https://codecov.io/github/chidiwilliams/buzz)\n![GitHub release (latest by date)](https://img.shields.io/github/v/release/chidiwilliams/buzz)\n[![Github all releases](https://img.shields.io/github/downloads/chidiwilliams/buzz/total.svg)](https://GitHub.com/chidiwilliams/buzz/releases/)\n\n
\n

Buzz is better on the App Store. Get a Mac-native version of Buzz with a cleaner look, audio playback, drag-and-drop import, transcript editing, search, and much more.

\nDownload on the Mac App Store\n
\n\n![Buzz](./assets/buzz-banner.jpg)\n\n## Features\n\n- Import audio and video files and export transcripts to TXT, SRT, and\n VTT ([Demo](https://www.loom.com/share/cf263b099ac3481082bb56d19b7c87fe))\n- Transcription and translation from your computer\'s microphones to text (Resource-intensive and may not be\n real-time, [Demo](https://www.loom.com/share/564b753eb4d44b55b985b8abd26b55f7))\n- Supports [Whisper](https://github.com/openai/whisper#available-models-and-languages),\n [Whisper.cpp](https://github.com/ggerganov/whisper.cpp), [Faster Whisper](https://github.com/guillaumekln/faster-whisper),\n [Whisper-compatible Hugging Face models](https://huggingface.co/models?other=whisper), and\n the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/introduction)\n- [Command-Line Interface](#command-line-interface)\n- Available on Mac, Windows, and Linux\n\n## Installation\n\nTo install Buzz, download the [latest version](https://github.com/chidiwilliams/buzz/releases/latest) for your operating\nsystem. Buzz is available on **Mac** (Intel), **Windows**, and **Linux**. (For Apple Silicon, please see the [App Store version](https://apps.apple.com/us/app/buzz-captions/id6446018936?mt=12&itsct=apps_box_badge&itscg=30200).)\n\n### Mac (Intel, macOS 11.7 and later)\n\n- Install via [brew](https://brew.sh/):\n\n ```shell\n brew install --cask buzz\n ```\n\n Alternatively, download and run the `Buzz-x.y.z.dmg` file.\n\n### Windows (Windows 10 and later)\n\n- Download and run the `Buzz-x.y.z.exe` file.\n\n### Linux (Ubuntu 20.04 and later)\n\n- Install dependencies:\n\n ```shell\n sudo apt-get install libportaudio2\n ```\n\n- Download and extract the `Buzz-x.y.z-unix.tar.gz` file\n\n## How to use\n\n### File import\n\nTo import a file:\n\n- Click Import Media File on the File menu (or the \'+\' icon on the toolbar, or **Command/Ctrl + O**).\n- Choose an audio or video file.\n- Select a task, language, and the model settings.\n- Click Run.\n- When the transcription status shows \'Completed\', double-click on the row (or select the row and click the \'⤢\' icon) to\n open the transcription.\n\n| Field | Options | Default | Description |\n| ------------------ | ------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Export As | "TXT", "SRT", "VTT" | "TXT" | Export file format |\n| Word-Level Timings | Off / On | Off | If checked, the transcription will generate a separate subtitle line for each word in the audio. Enabled only when "Export As" is set to "SRT" or "VTT". |\n\n(See the [Live Recording section](#live-recording) for more information about the task, language, and quality settings.)\n\n[![Media File Import on Buzz](https://cdn.loom.com/sessions/thumbnails/cf263b099ac3481082bb56d19b7c87fe-with-play.gif)](https://www.loom.com/share/cf263b099ac3481082bb56d19b7c87fe "Media File Import on Buzz")\n\n### Live Recording\n\nTo start a live recording:\n\n- Select a recording task, language, quality, and microphone.\n- Click Record.\n\n> **Note:** Transcribing audio using the default Whisper model is resource-intensive. Consider using the Whisper.cpp\n> Tiny model to get real-time performance.\n\n| Field | Options | Default | Description |\n| ---------- | ---------------------------------------------------------------------------------------------------------------------------------------- | --------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Task | "Transcribe", "Translate" | "Transcribe" | "Transcribe" converts the input audio into text in the selected language, while "Translate" converts it into text in English. |\n| Language | See [Whisper\'s documentation](https://github.com/openai/whisper#available-models-and-languages) for the full list of supported languages | "Detect Language" | "Detect Language" will try to detect the spoken language in the audio based on the first few seconds. However, selecting a language is recommended (if known) as it will improve transcription quality in many cases. |\n| Quality | "Very Low", "Low", "Medium", "High" | "Very Low" | The transcription quality determines the Whisper model used for transcription. "Very Low" uses the "tiny" model; "Low" uses the "base" model; "Medium" uses the "small" model; and "High" uses the "medium" model. The larger models produce higher-quality transcriptions, but require more system resources. See [Whisper\'s documentation](https://github.com/openai/whisper#available-models-and-languages) for more information about the models. |\n| Microphone | [Available system microphones] | [Default system microphone] | Microphone for recording input audio. |\n\n[![Live Recording on Buzz](https://cdn.loom.com/sessions/thumbnails/564b753eb4d44b55b985b8abd26b55f7-with-play.gif)](https://www.loom.com/share/564b753eb4d44b55b985b8abd26b55f7 "Live Recording on Buzz")\n\n### Record audio playing from computer\n\nTo record audio playing from an application on your computer, you may install an audio loopback driver (a program that\nlets you create virtual audio devices). The rest of this guide will\nuse [BlackHole](https://github.com/ExistentialAudio/BlackHole) on Mac, but you can use other alternatives for your\noperating system (\nsee [LoopBeAudio](https://nerds.de/en/loopbeaudio.html), [LoopBack](https://rogueamoeba.com/loopback/),\nand [Virtual Audio Cable](https://vac.muzychenko.net/en/)).\n\n1. Install [BlackHole via Homebrew](https://github.com/ExistentialAudio/BlackHole#option-2-install-via-homebrew)\n\n ```shell\n brew install blackhole-2ch\n ```\n\n2. Open Audio MIDI Setup from Spotlight or from `/Applications/Utilities/Audio Midi Setup.app`.\n\n ![Open Audio MIDI Setup from Spotlight](https://existential.audio/howto/img/spotlight.png)\n\n3. Click the \'+\' icon at the lower left corner and select \'Create Multi-Output Device\'.\n\n ![Create multi-output device](https://existential.audio/howto/img/createmulti-output.png)\n\n4. Add your default speaker and BlackHole to the multi-output device.\n\n ![Screenshot of multi-output device](https://existential.audio/howto/img/multi-output.png)\n\n5. Select this multi-output device as your speaker (application or system-wide) to play audio into BlackHole.\n\n6. Open Buzz, select BlackHole as your microphone, and record as before to see transcriptions from the audio playing\n through BlackHole.\n\n## Command-Line Interface\n\n### `add`\n\nStart a new transcription task\n\nExamples:\n\n```shell\n# Translate two MP3 files from French to English using OpenAI Whisper API\nbuzz add --task translate --language fr --model-type openaiapi /Users/user/Downloads/1b3b03e4-8db5-ea2c-ace5-b71ff32e3304.mp3 /Users/user/Downloads/koaf9083k1lkpsfdi0.mp3\n\n# Transcribe an MP4 using Whisper.cpp "small" model and immediately export to SRT and VTT files\nbuzz add --task transcribe --model-type whispercpp --model-size small --prompt "My initial prompt" --srt --vtt /Users/user/Downloads/buzz/1b3b03e4-8db5-ea2c-ace5-b71ff32e3304.mp4\n```\n\nRun `buzz add --help` to see all available options.\n\n## Build\n\nTo build/run Buzz locally from source, first install the requirements:\n\n1. [Poetry](https://python-poetry.org/docs/#installing-with-the-official-installer)\n\nThen:\n\n1. Clone the repository\n\n ```shell\n git clone --recurse-submodules https://github.com/chidiwilliams/buzz\n ```\n\n2. Install the project dependencies.\n\n ```shell\n poetry install\n ```\n\n3. (Optional) To use Whisper.cpp inference, run:\n\n ```shell\n make buzz/whisper_cpp.py\n ```\n\n4. (Optional) To compile the translations, run:\n\n ```shell\n make translation_mo\n ```\n\n5. Finally, run the app with:\n\n ```shell\n poetry run python main.py\n ```\n\n Or build with:\n\n ```shell\n poetry run pyinstaller --noconfirm Buzz.spec\n ```\n\n## FAQ\n\n1. **Where are the models stored?**\n\n The Whisper models are stored in `~/.cache/whisper`. The Whisper.cpp models are stored in `~/Library/Caches/Buzz` (\n Mac OS), `~/.cache/Buzz` (Unix), or `C:\\Users\\\\AppData\\Local\\Buzz\\Buzz\\Cache` (Windows). The Hugging Face\n models are stored in `~/.cache/huggingface/hub`.\n\n2. **What can I try if the transcription runs too slowly?**\n\n Try using a lower Whisper model size or using a Whisper.cpp model.\n\n## Credits\n\n- SVG Icons: [Google Fonts Material Symbols](https://fonts.google.com/icons)\n', - 'author': 'Chidi Williams', - 'author_email': 'williamschidi1@gmail.com', - 'maintainer': 'None', - 'maintainer_email': 'None', - 'url': 'None', - 'packages': packages, - 'package_data': package_data, - 'install_requires': install_requires, - 'python_requires': '>=3.9.13,<3.11', + "name": "buzz", + "version": "0.8.3", + "description": "", + "long_description": '# Buzz\n\nTranscribe and translate audio offline on your personal computer. Powered by\nOpenAI\'s [Whisper](https://github.com/openai/whisper).\n\n![MIT License](https://img.shields.io/badge/license-MIT-green)\n[![CI](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml/badge.svg)](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml)\n[![codecov](https://codecov.io/github/chidiwilliams/buzz/branch/main/graph/badge.svg?token=YJSB8S2VEP)](https://codecov.io/github/chidiwilliams/buzz)\n![GitHub release (latest by date)](https://img.shields.io/github/v/release/chidiwilliams/buzz)\n[![Github all releases](https://img.shields.io/github/downloads/chidiwilliams/buzz/total.svg)](https://GitHub.com/chidiwilliams/buzz/releases/)\n\n
\n

Buzz is better on the App Store. Get a Mac-native version of Buzz with a cleaner look, audio playback, drag-and-drop import, transcript editing, search, and much more.

\nDownload on the Mac App Store\n
\n\n![Buzz](./assets/buzz-banner.jpg)\n\n## Features\n\n- Import audio and video files and export transcripts to TXT, SRT, and\n VTT ([Demo](https://www.loom.com/share/cf263b099ac3481082bb56d19b7c87fe))\n- Transcription and translation from your computer\'s microphones to text (Resource-intensive and may not be\n real-time, [Demo](https://www.loom.com/share/564b753eb4d44b55b985b8abd26b55f7))\n- Supports [Whisper](https://github.com/openai/whisper#available-models-and-languages),\n [Whisper.cpp](https://github.com/ggerganov/whisper.cpp), [Faster Whisper](https://github.com/guillaumekln/faster-whisper),\n [Whisper-compatible Hugging Face models](https://huggingface.co/models?other=whisper), and\n the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/introduction)\n- [Command-Line Interface](#command-line-interface)\n- Available on Mac, Windows, and Linux\n\n## Installation\n\nTo install Buzz, download the [latest version](https://github.com/chidiwilliams/buzz/releases/latest) for your operating\nsystem. Buzz is available on **Mac** (Intel), **Windows**, and **Linux**. (For Apple Silicon, please see the [App Store version](https://apps.apple.com/us/app/buzz-captions/id6446018936?mt=12&itsct=apps_box_badge&itscg=30200).)\n\n### Mac (Intel, macOS 11.7 and later)\n\n- Install via [brew](https://brew.sh/):\n\n ```shell\n brew install --cask buzz\n ```\n\n Alternatively, download and run the `Buzz-x.y.z.dmg` file.\n\n### Windows (Windows 10 and later)\n\n- Download and run the `Buzz-x.y.z.exe` file.\n\n### Linux (Ubuntu 20.04 and later)\n\n- Install dependencies:\n\n ```shell\n sudo apt-get install libportaudio2\n ```\n\n- Download and extract the `Buzz-x.y.z-unix.tar.gz` file\n\n## How to use\n\n### File import\n\nTo import a file:\n\n- Click Import Media File on the File menu (or the \'+\' icon on the toolbar, or **Command/Ctrl + O**).\n- Choose an audio or video file.\n- Select a task, language, and the model settings.\n- Click Run.\n- When the transcription status shows \'Completed\', double-click on the row (or select the row and click the \'⤢\' icon) to\n open the transcription.\n\n| Field | Options | Default | Description |\n| ------------------ | ------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Export As | "TXT", "SRT", "VTT" | "TXT" | Export file format |\n| Word-Level Timings | Off / On | Off | If checked, the transcription will generate a separate subtitle line for each word in the audio. Enabled only when "Export As" is set to "SRT" or "VTT". |\n\n(See the [Live Recording section](#live-recording) for more information about the task, language, and quality settings.)\n\n[![Media File Import on Buzz](https://cdn.loom.com/sessions/thumbnails/cf263b099ac3481082bb56d19b7c87fe-with-play.gif)](https://www.loom.com/share/cf263b099ac3481082bb56d19b7c87fe "Media File Import on Buzz")\n\n### Live Recording\n\nTo start a live recording:\n\n- Select a recording task, language, quality, and microphone.\n- Click Record.\n\n> **Note:** Transcribing audio using the default Whisper model is resource-intensive. Consider using the Whisper.cpp\n> Tiny model to get real-time performance.\n\n| Field | Options | Default | Description |\n| ---------- | ---------------------------------------------------------------------------------------------------------------------------------------- | --------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Task | "Transcribe", "Translate" | "Transcribe" | "Transcribe" converts the input audio into text in the selected language, while "Translate" converts it into text in English. |\n| Language | See [Whisper\'s documentation](https://github.com/openai/whisper#available-models-and-languages) for the full list of supported languages | "Detect Language" | "Detect Language" will try to detect the spoken language in the audio based on the first few seconds. However, selecting a language is recommended (if known) as it will improve transcription quality in many cases. |\n| Quality | "Very Low", "Low", "Medium", "High" | "Very Low" | The transcription quality determines the Whisper model used for transcription. "Very Low" uses the "tiny" model; "Low" uses the "base" model; "Medium" uses the "small" model; and "High" uses the "medium" model. The larger models produce higher-quality transcriptions, but require more system resources. See [Whisper\'s documentation](https://github.com/openai/whisper#available-models-and-languages) for more information about the models. |\n| Microphone | [Available system microphones] | [Default system microphone] | Microphone for recording input audio. |\n\n[![Live Recording on Buzz](https://cdn.loom.com/sessions/thumbnails/564b753eb4d44b55b985b8abd26b55f7-with-play.gif)](https://www.loom.com/share/564b753eb4d44b55b985b8abd26b55f7 "Live Recording on Buzz")\n\n### Record audio playing from computer\n\nTo record audio playing from an application on your computer, you may install an audio loopback driver (a program that\nlets you create virtual audio devices). The rest of this guide will\nuse [BlackHole](https://github.com/ExistentialAudio/BlackHole) on Mac, but you can use other alternatives for your\noperating system (\nsee [LoopBeAudio](https://nerds.de/en/loopbeaudio.html), [LoopBack](https://rogueamoeba.com/loopback/),\nand [Virtual Audio Cable](https://vac.muzychenko.net/en/)).\n\n1. Install [BlackHole via Homebrew](https://github.com/ExistentialAudio/BlackHole#option-2-install-via-homebrew)\n\n ```shell\n brew install blackhole-2ch\n ```\n\n2. Open Audio MIDI Setup from Spotlight or from `/Applications/Utilities/Audio Midi Setup.app`.\n\n ![Open Audio MIDI Setup from Spotlight](https://existential.audio/howto/img/spotlight.png)\n\n3. Click the \'+\' icon at the lower left corner and select \'Create Multi-Output Device\'.\n\n ![Create multi-output device](https://existential.audio/howto/img/createmulti-output.png)\n\n4. Add your default speaker and BlackHole to the multi-output device.\n\n ![Screenshot of multi-output device](https://existential.audio/howto/img/multi-output.png)\n\n5. Select this multi-output device as your speaker (application or system-wide) to play audio into BlackHole.\n\n6. Open Buzz, select BlackHole as your microphone, and record as before to see transcriptions from the audio playing\n through BlackHole.\n\n## Command-Line Interface\n\n### `add`\n\nStart a new transcription task\n\nExamples:\n\n```shell\n# Translate two MP3 files from French to English using OpenAI Whisper API\nbuzz add --task translate --language fr --model-type openaiapi /Users/user/Downloads/1b3b03e4-8db5-ea2c-ace5-b71ff32e3304.mp3 /Users/user/Downloads/koaf9083k1lkpsfdi0.mp3\n\n# Transcribe an MP4 using Whisper.cpp "small" model and immediately export to SRT and VTT files\nbuzz add --task transcribe --model-type whispercpp --model-size small --prompt "My initial prompt" --srt --vtt /Users/user/Downloads/buzz/1b3b03e4-8db5-ea2c-ace5-b71ff32e3304.mp4\n```\n\nRun `buzz add --help` to see all available options.\n\n## Build\n\nTo build/run Buzz locally from source, first install the requirements:\n\n1. [Poetry](https://python-poetry.org/docs/#installing-with-the-official-installer)\n\nThen:\n\n1. Clone the repository\n\n ```shell\n git clone --recurse-submodules https://github.com/chidiwilliams/buzz\n ```\n\n2. Install the project dependencies.\n\n ```shell\n poetry install\n ```\n\n3. (Optional) To use Whisper.cpp inference, run:\n\n ```shell\n make buzz/whisper_cpp.py\n ```\n\n4. (Optional) To compile the translations, run:\n\n ```shell\n make translation_mo\n ```\n\n5. Finally, run the app with:\n\n ```shell\n poetry run python main.py\n ```\n\n Or build with:\n\n ```shell\n poetry run pyinstaller --noconfirm Buzz.spec\n ```\n\n## FAQ\n\n1. **Where are the models stored?**\n\n The Whisper models are stored in `~/.cache/whisper`. The Whisper.cpp models are stored in `~/Library/Caches/Buzz` (\n Mac OS), `~/.cache/Buzz` (Unix), or `C:\\Users\\\\AppData\\Local\\Buzz\\Buzz\\Cache` (Windows). The Hugging Face\n models are stored in `~/.cache/huggingface/hub`.\n\n2. **What can I try if the transcription runs too slowly?**\n\n Try using a lower Whisper model size or using a Whisper.cpp model.\n\n## Credits\n\n- SVG Icons: [Google Fonts Material Symbols](https://fonts.google.com/icons)\n', + "author": "Chidi Williams", + "author_email": "williamschidi1@gmail.com", + "maintainer": "None", + "maintainer_email": "None", + "url": "None", + "packages": packages, + "package_data": package_data, + "install_requires": install_requires, + "python_requires": ">=3.9.13,<3.11", } from build import * + build(setup_kwargs) setup(**setup_kwargs) diff --git a/tests/cache_test.py b/tests/cache_test.py index 520fbf11..f74c3b80 100644 --- a/tests/cache_test.py +++ b/tests/cache_test.py @@ -1,16 +1,31 @@ from buzz.cache import TasksCache -from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, - TranscriptionOptions) +from buzz.transcriber import ( + FileTranscriptionOptions, + FileTranscriptionTask, + TranscriptionOptions, +) class TestTasksCache: def test_should_save_and_load(self, tmp_path): cache = TasksCache(cache_dir=str(tmp_path)) - tasks = [FileTranscriptionTask(file_path='1.mp3', transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions(file_paths=['1.mp3']), - model_path=''), - FileTranscriptionTask(file_path='2.mp3', transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions(file_paths=['2.mp3']), - model_path='')] + tasks = [ + FileTranscriptionTask( + file_path="1.mp3", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions( + file_paths=["1.mp3"] + ), + model_path="", + ), + FileTranscriptionTask( + file_path="2.mp3", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions( + file_paths=["2.mp3"] + ), + model_path="", + ), + ] cache.save(tasks) assert cache.load() == tasks diff --git a/tests/gui_test.py b/tests/gui_test.py index 92362cec..3d0ae1bd 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -8,91 +8,100 @@ import pytest import sounddevice from PyQt6.QtCore import QSize, Qt from PyQt6.QtGui import QValidator, QKeyEvent -from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication, QMessageBox +from PyQt6.QtWidgets import ( + QPushButton, + QToolBar, + QTableWidget, + QApplication, + QMessageBox, +) from _pytest.fixtures import SubRequest from pytestqt.qtbot import QtBot from buzz.__version__ import VERSION from buzz.cache import TasksCache -from buzz.gui import (AudioDevicesComboBox, MainWindow, - RecordingTranscriberWidget) +from buzz.gui import AudioDevicesComboBox, MainWindow, RecordingTranscriberWidget from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget -from buzz.widgets.transcriber.hugging_face_search_line_edit import \ - HuggingFaceSearchLineEdit +from buzz.widgets.transcriber.hugging_face_search_line_edit import ( + HuggingFaceSearchLineEdit, +) from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox from buzz.widgets.transcriber.temperature_validator import TemperatureValidator from buzz.widgets.about_dialog import AboutDialog from buzz.model_loader import ModelType from buzz.settings.settings import Settings -from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, - TranscriptionOptions) -from buzz.widgets.transcriber.transcription_options_group_box import \ - TranscriptionOptionsGroupBox +from buzz.transcriber import ( + FileTranscriptionOptions, + FileTranscriptionTask, + TranscriptionOptions, +) +from buzz.widgets.transcriber.transcription_options_group_box import ( + TranscriptionOptionsGroupBox, +) from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget from tests.mock_sounddevice import MockInputStream, mock_query_devices from .mock_qt import MockNetworkAccessManager, MockNetworkReply -if platform.system() == 'Linux': - multiprocessing.set_start_method('spawn') +if platform.system() == "Linux": + multiprocessing.set_start_method("spawn") -@pytest.fixture(scope='module', autouse=True) +@pytest.fixture(scope="module", autouse=True) def audio_setup(): - with patch('sounddevice.query_devices') as query_devices_mock, \ - patch('sounddevice.InputStream', side_effect=MockInputStream), \ - patch('sounddevice.check_input_settings'): + with patch("sounddevice.query_devices") as query_devices_mock, patch( + "sounddevice.InputStream", side_effect=MockInputStream + ), patch("sounddevice.check_input_settings"): query_devices_mock.return_value = mock_query_devices sounddevice.default.device = 3, 4 yield class TestLanguagesComboBox: - def test_should_show_sorted_whisper_languages(self, qtbot): - languages_combox_box = LanguagesComboBox('en') + languages_combox_box = LanguagesComboBox("en") qtbot.add_widget(languages_combox_box) - assert languages_combox_box.itemText(0) == 'Detect Language' - assert languages_combox_box.itemText(10) == 'Belarusian' - assert languages_combox_box.itemText(20) == 'Dutch' - assert languages_combox_box.itemText(30) == 'Gujarati' - assert languages_combox_box.itemText(40) == 'Japanese' - assert languages_combox_box.itemText(50) == 'Lithuanian' + assert languages_combox_box.itemText(0) == "Detect Language" + assert languages_combox_box.itemText(10) == "Belarusian" + assert languages_combox_box.itemText(20) == "Dutch" + assert languages_combox_box.itemText(30) == "Gujarati" + assert languages_combox_box.itemText(40) == "Japanese" + assert languages_combox_box.itemText(50) == "Lithuanian" def test_should_select_en_as_default_language(self, qtbot): - languages_combox_box = LanguagesComboBox('en') + languages_combox_box = LanguagesComboBox("en") qtbot.add_widget(languages_combox_box) - assert languages_combox_box.currentText() == 'English' + assert languages_combox_box.currentText() == "English" def test_should_select_detect_language_as_default(self, qtbot): languages_combo_box = LanguagesComboBox(None) qtbot.add_widget(languages_combo_box) - assert languages_combo_box.currentText() == 'Detect Language' + assert languages_combo_box.currentText() == "Detect Language" class TestAudioDevicesComboBox: def test_get_devices(self): audio_devices_combo_box = AudioDevicesComboBox() - assert audio_devices_combo_box.itemText(0) == 'Background Music' - assert audio_devices_combo_box.itemText(1) == 'Background Music (UI Sounds)' - assert audio_devices_combo_box.itemText(2) == 'BlackHole 2ch' - assert audio_devices_combo_box.itemText(3) == 'MacBook Pro Microphone' - assert audio_devices_combo_box.itemText(4) == 'Null Audio Device' + assert audio_devices_combo_box.itemText(0) == "Background Music" + assert audio_devices_combo_box.itemText(1) == "Background Music (UI Sounds)" + assert audio_devices_combo_box.itemText(2) == "BlackHole 2ch" + assert audio_devices_combo_box.itemText(3) == "MacBook Pro Microphone" + assert audio_devices_combo_box.itemText(4) == "Null Audio Device" - assert audio_devices_combo_box.currentText() == 'MacBook Pro Microphone' + assert audio_devices_combo_box.currentText() == "MacBook Pro Microphone" def test_select_default_mic_when_no_default(self): sounddevice.default.device = -1, 1 audio_devices_combo_box = AudioDevicesComboBox() - assert audio_devices_combo_box.currentText() == 'Background Music' + assert audio_devices_combo_box.currentText() == "Background Music" @pytest.fixture() def tasks_cache(tmp_path, request: SubRequest): cache = TasksCache(cache_dir=str(tmp_path)) - if hasattr(request, 'param'): + if hasattr(request, "param"): tasks: List[FileTranscriptionTask] = request.param cache.save(tasks) yield cache @@ -100,28 +109,40 @@ def tasks_cache(tmp_path, request: SubRequest): def get_test_asset(filename: str): - return os.path.join(os.path.dirname(__file__), '../testdata/', filename) + return os.path.join(os.path.dirname(__file__), "../testdata/", filename) mock_tasks = [ - FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='', - status=FileTranscriptionTask.Status.COMPLETED), - FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='', - status=FileTranscriptionTask.Status.CANCELED), - FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='', - status=FileTranscriptionTask.Status.FAILED, error='Error'), + FileTranscriptionTask( + file_path="", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions(file_paths=[]), + model_path="", + status=FileTranscriptionTask.Status.COMPLETED, + ), + FileTranscriptionTask( + file_path="", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions(file_paths=[]), + model_path="", + status=FileTranscriptionTask.Status.CANCELED, + ), + FileTranscriptionTask( + file_path="", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions(file_paths=[]), + model_path="", + status=FileTranscriptionTask.Status.FAILED, + error="Error", + ), ] class TestMainWindow: - def test_should_set_window_title_and_icon(self, qtbot): window = MainWindow() qtbot.add_widget(window) - assert window.windowTitle() == 'Buzz' + assert window.windowTitle() == "Buzz" assert window.windowIcon().pixmap(QSize(64, 64)).isNull() is False window.close() @@ -132,13 +153,18 @@ class TestMainWindow: self._start_new_transcription(window) - open_transcript_action = self._get_toolbar_action(window, 'Open Transcript') + open_transcript_action = self._get_toolbar_action(window, "Open Transcript") assert open_transcript_action.isEnabled() is False table_widget: QTableWidget = window.findChild(QTableWidget) - qtbot.wait_until(self._assert_task_status(table_widget, 0, 'Completed'), timeout=2 * 60 * 1000) + qtbot.wait_until( + self._assert_task_status(table_widget, 0, "Completed"), + timeout=2 * 60 * 1000, + ) - table_widget.setCurrentIndex(table_widget.indexFromItem(table_widget.item(0, 1))) + table_widget.setCurrentIndex( + table_widget.indexFromItem(table_widget.item(0, 1)) + ) assert open_transcript_action.isEnabled() # @pytest.mark.skip(reason='Timing out or crashing') @@ -152,8 +178,8 @@ class TestMainWindow: def assert_task_in_progress(): assert table_widget.rowCount() > 0 - assert table_widget.item(0, 1).text() == 'whisper-french.mp3' - assert 'In Progress' in table_widget.item(0, 2).text() + assert table_widget.item(0, 1).text() == "whisper-french.mp3" + assert "In Progress" in table_widget.item(0, 2).text() qtbot.wait_until(assert_task_in_progress, timeout=2 * 60 * 1000) @@ -161,7 +187,9 @@ class TestMainWindow: table_widget.selectRow(0) window.toolbar.stop_transcription_action.trigger() - qtbot.wait_until(self._assert_task_status(table_widget, 0, 'Canceled'), timeout=60 * 1000) + qtbot.wait_until( + self._assert_task_status(table_widget, 0, "Canceled"), timeout=60 * 1000 + ) table_widget.selectRow(0) assert window.toolbar.stop_transcription_action.isEnabled() is False @@ -169,7 +197,7 @@ class TestMainWindow: window.close() - @pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True) + @pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True) def test_should_load_tasks_from_cache(self, qtbot, tasks_cache): window = MainWindow(tasks_cache=tasks_cache) qtbot.add_widget(window) @@ -177,43 +205,47 @@ class TestMainWindow: table_widget: QTableWidget = window.findChild(QTableWidget) assert table_widget.rowCount() == 3 - assert table_widget.item(0, 2).text() == 'Completed' + assert table_widget.item(0, 2).text() == "Completed" table_widget.selectRow(0) assert window.toolbar.open_transcript_action.isEnabled() - assert table_widget.item(1, 2).text() == 'Canceled' + assert table_widget.item(1, 2).text() == "Canceled" table_widget.selectRow(1) assert window.toolbar.open_transcript_action.isEnabled() is False - assert table_widget.item(2, 2).text() == 'Failed (Error)' + assert table_widget.item(2, 2).text() == "Failed (Error)" table_widget.selectRow(2) assert window.toolbar.open_transcript_action.isEnabled() is False window.close() - @pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True) + @pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True) def test_should_clear_history_with_rows_selected(self, qtbot, tasks_cache): window = MainWindow(tasks_cache=tasks_cache) table_widget: QTableWidget = window.findChild(QTableWidget) table_widget.selectAll() - with patch('PyQt6.QtWidgets.QMessageBox.question') as question_message_box_mock: + with patch("PyQt6.QtWidgets.QMessageBox.question") as question_message_box_mock: question_message_box_mock.return_value = QMessageBox.StandardButton.Yes window.toolbar.clear_history_action.trigger() assert table_widget.rowCount() == 0 window.close() - @pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True) - def test_should_have_clear_history_action_disabled_with_no_rows_selected(self, qtbot, tasks_cache): + @pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True) + def test_should_have_clear_history_action_disabled_with_no_rows_selected( + self, qtbot, tasks_cache + ): window = MainWindow(tasks_cache=tasks_cache) qtbot.add_widget(window) assert window.toolbar.clear_history_action.isEnabled() is False window.close() - @pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True) - def test_should_open_transcription_viewer_when_menu_action_is_clicked(self, qtbot, tasks_cache): + @pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True) + def test_should_open_transcription_viewer_when_menu_action_is_clicked( + self, qtbot, tasks_cache + ): window = MainWindow(tasks_cache=tasks_cache) qtbot.add_widget(window) @@ -228,23 +260,33 @@ class TestMainWindow: window.close() - @pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True) - def test_should_open_transcription_viewer_when_return_clicked(self, qtbot, tasks_cache): + @pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True) + def test_should_open_transcription_viewer_when_return_clicked( + self, qtbot, tasks_cache + ): window = MainWindow(tasks_cache=tasks_cache) qtbot.add_widget(window) table_widget: QTableWidget = window.findChild(QTableWidget) table_widget.selectRow(0) table_widget.keyPressEvent( - QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier, '\r')) + QKeyEvent( + QKeyEvent.Type.KeyPress, + Qt.Key.Key_Return, + Qt.KeyboardModifier.NoModifier, + "\r", + ) + ) transcription_viewer = window.findChild(TranscriptionViewerWidget) assert transcription_viewer is not None window.close() - @pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True) - def test_should_have_open_transcript_action_disabled_with_no_rows_selected(self, qtbot, tasks_cache): + @pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True) + def test_should_have_open_transcript_action_disabled_with_no_rows_selected( + self, qtbot, tasks_cache + ): window = MainWindow(tasks_cache=tasks_cache) qtbot.add_widget(window) @@ -253,20 +295,31 @@ class TestMainWindow: @staticmethod def _start_new_transcription(window: MainWindow): - with patch('PyQt6.QtWidgets.QFileDialog.getOpenFileNames') as open_file_names_mock: - open_file_names_mock.return_value = ([get_test_asset('whisper-french.mp3')], '') - new_transcription_action = TestMainWindow._get_toolbar_action(window, 'New Transcription') + with patch( + "PyQt6.QtWidgets.QFileDialog.getOpenFileNames" + ) as open_file_names_mock: + open_file_names_mock.return_value = ( + [get_test_asset("whisper-french.mp3")], + "", + ) + new_transcription_action = TestMainWindow._get_toolbar_action( + window, "New Transcription" + ) new_transcription_action.trigger() - file_transcriber_widget: FileTranscriberWidget = window.findChild(FileTranscriberWidget) + file_transcriber_widget: FileTranscriberWidget = window.findChild( + FileTranscriberWidget + ) run_button: QPushButton = file_transcriber_widget.findChild(QPushButton) run_button.click() @staticmethod - def _assert_task_status(table_widget: QTableWidget, row_index: int, expected_status: str): + def _assert_task_status( + table_widget: QTableWidget, row_index: int, expected_status: str + ): def assert_task_canceled(): assert table_widget.rowCount() > 0 - assert table_widget.item(row_index, 1).text() == 'whisper-french.mp3' + assert table_widget.item(row_index, 1).text() == "whisper-french.mp3" assert expected_status in table_widget.item(row_index, 2).text() return assert_task_canceled @@ -277,7 +330,7 @@ class TestMainWindow: return [action for action in toolbar.actions() if action.text() == text][0] -@pytest.fixture(scope='module', autouse=True) +@pytest.fixture(scope="module", autouse=True) def clear_settings(): settings = Settings() settings.clear() @@ -285,7 +338,7 @@ def clear_settings(): class TestAboutDialog: def test_should_check_for_updates(self, qtbot: QtBot): - reply = MockNetworkReply(data={'name': 'v' + VERSION}) + reply = MockNetworkReply(data={"name": "v" + VERSION}) manager = MockNetworkAccessManager(reply=reply) dialog = AboutDialog(network_access_manager=manager) qtbot.add_widget(dialog) @@ -296,41 +349,45 @@ class TestAboutDialog: with qtbot.wait_signal(dialog.network_access_manager.finished): dialog.check_updates_button.click() - mock_message_box_information.assert_called_with(dialog, '', "You're up to date!") + mock_message_box_information.assert_called_with( + dialog, "", "You're up to date!" + ) class TestAdvancedSettingsDialog: def test_should_update_advanced_settings(self, qtbot: QtBot): dialog = AdvancedSettingsDialog( - transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt')) + transcription_options=TranscriptionOptions( + temperature=(0.0, 0.8), initial_prompt="prompt" + ) + ) qtbot.add_widget(dialog) transcription_options_mock = Mock() - dialog.transcription_options_changed.connect( - transcription_options_mock) + dialog.transcription_options_changed.connect(transcription_options_mock) - assert dialog.windowTitle() == 'Advanced Settings' - assert dialog.temperature_line_edit.text() == '0.0, 0.8' - assert dialog.initial_prompt_text_edit.toPlainText() == 'prompt' + assert dialog.windowTitle() == "Advanced Settings" + assert dialog.temperature_line_edit.text() == "0.0, 0.8" + assert dialog.initial_prompt_text_edit.toPlainText() == "prompt" - dialog.temperature_line_edit.setText('0.0, 0.8, 1.0') - dialog.initial_prompt_text_edit.setPlainText('new prompt') + dialog.temperature_line_edit.setText("0.0, 0.8, 1.0") + dialog.initial_prompt_text_edit.setPlainText("new prompt") - assert transcription_options_mock.call_args[0][0].temperature == ( - 0.0, 0.8, 1.0) - assert transcription_options_mock.call_args[0][0].initial_prompt == 'new prompt' + assert transcription_options_mock.call_args[0][0].temperature == (0.0, 0.8, 1.0) + assert transcription_options_mock.call_args[0][0].initial_prompt == "new prompt" class TestTemperatureValidator: validator = TemperatureValidator(None) @pytest.mark.parametrize( - 'text,state', + "text,state", [ - ('0.0,0.5,1.0', QValidator.State.Acceptable), - ('0.0,0.5,', QValidator.State.Intermediate), - ('0.0,0.5,p', QValidator.State.Invalid), - ]) + ("0.0,0.5,1.0", QValidator.State.Acceptable), + ("0.0,0.5,", QValidator.State.Intermediate), + ("0.0,0.5,p", QValidator.State.Invalid), + ], + ) def test_should_validate_temperature(self, text: str, state: QValidator.State): assert self.validator.validate(text, 0)[0] == state @@ -339,9 +396,9 @@ class TestRecordingTranscriberWidget: def test_should_set_window_title(self, qtbot: QtBot): widget = RecordingTranscriberWidget() qtbot.add_widget(widget) - assert widget.windowTitle() == 'Live Recording' + assert widget.windowTitle() == "Live Recording" - @pytest.mark.skip(reason='Seg faults on CI') + @pytest.mark.skip(reason="Seg faults on CI") def test_should_transcribe(self, qtbot): widget = RecordingTranscriberWidget() qtbot.add_widget(widget) @@ -355,31 +412,37 @@ class TestRecordingTranscriberWidget: with qtbot.wait_signal(widget.transcription_thread.finished, timeout=60 * 1000): widget.stop_recording() - assert 'Welcome to Passe' in widget.text_box.toPlainText() + assert "Welcome to Passe" in widget.text_box.toPlainText() class TestHuggingFaceSearchLineEdit: def test_should_update_selected_model_on_type(self, qtbot: QtBot): - widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager()) + widget = HuggingFaceSearchLineEdit( + network_access_manager=self.network_access_manager() + ) qtbot.add_widget(widget) mock_model_selected = Mock() widget.model_selected.connect(mock_model_selected) self._set_text_and_wait_response(qtbot, widget) - mock_model_selected.assert_called_with('openai/whisper-tiny') + mock_model_selected.assert_called_with("openai/whisper-tiny") def test_should_show_list_of_models(self, qtbot: QtBot): - widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager()) + widget = HuggingFaceSearchLineEdit( + network_access_manager=self.network_access_manager() + ) qtbot.add_widget(widget) self._set_text_and_wait_response(qtbot, widget) assert widget.popup.count() > 0 - assert 'openai/whisper-tiny' in widget.popup.item(0).text() + assert "openai/whisper-tiny" in widget.popup.item(0).text() def test_should_select_model_from_list(self, qtbot: QtBot): - widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager()) + widget = HuggingFaceSearchLineEdit( + network_access_manager=self.network_access_manager() + ) qtbot.add_widget(widget) mock_model_selected = Mock() @@ -388,23 +451,35 @@ class TestHuggingFaceSearchLineEdit: self._set_text_and_wait_response(qtbot, widget) # press down arrow and enter to select next item - QApplication.sendEvent(widget.popup, - QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier)) - QApplication.sendEvent(widget.popup, - QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Enter, Qt.KeyboardModifier.NoModifier)) + QApplication.sendEvent( + widget.popup, + QKeyEvent( + QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier + ), + ) + QApplication.sendEvent( + widget.popup, + QKeyEvent( + QKeyEvent.Type.KeyPress, + Qt.Key.Key_Enter, + Qt.KeyboardModifier.NoModifier, + ), + ) - mock_model_selected.assert_called_with('openai/whisper-tiny.en') + mock_model_selected.assert_called_with("openai/whisper-tiny.en") @staticmethod def network_access_manager(): - reply = MockNetworkReply(data=[{'id': 'openai/whisper-tiny'}, {'id': 'openai/whisper-tiny.en'}]) + reply = MockNetworkReply( + data=[{"id": "openai/whisper-tiny"}, {"id": "openai/whisper-tiny.en"}] + ) return MockNetworkAccessManager(reply=reply) @staticmethod def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit): with qtbot.wait_signal(widget.network_manager.finished): - widget.setText('openai/whisper-tiny') - widget.textEdited.emit('openai/whisper-tiny') + widget.setText("openai/whisper-tiny") + widget.textEdited.emit("openai/whisper-tiny") class TestTranscriptionOptionsGroupBox: @@ -417,5 +492,7 @@ class TestTranscriptionOptionsGroupBox: widget.model_type_combo_box.setCurrentIndex(1) - transcription_options: TranscriptionOptions = mock_transcription_options_changed.call_args[0][0] + transcription_options: TranscriptionOptions = ( + mock_transcription_options_changed.call_args[0][0] + ) assert transcription_options.model.model_type == ModelType.WHISPER_CPP diff --git a/tests/mock_qt.py b/tests/mock_qt.py index 616939f9..15cc5837 100644 --- a/tests/mock_qt.py +++ b/tests/mock_qt.py @@ -1,4 +1,3 @@ - import json from typing import Optional @@ -10,10 +9,10 @@ class MockNetworkReply(QNetworkReply): def __init__(self, data: object, _: Optional[QObject] = None) -> None: self.data = data - def readAll(self) -> 'QByteArray': - return QByteArray(json.dumps(self.data).encode('utf-8')) + def readAll(self) -> "QByteArray": + return QByteArray(json.dumps(self.data).encode("utf-8")) - def error(self) -> 'QNetworkReply.NetworkError': + def error(self) -> "QNetworkReply.NetworkError": return QNetworkReply.NetworkError.NoError @@ -21,10 +20,12 @@ class MockNetworkAccessManager(QNetworkAccessManager): finished = pyqtSignal(object) reply: MockNetworkReply - def __init__(self, reply: MockNetworkReply, parent: Optional[QObject] = None) -> None: + def __init__( + self, reply: MockNetworkReply, parent: Optional[QObject] = None + ) -> None: super().__init__(parent) self.reply = reply - def get(self, _: 'QNetworkRequest') -> 'QNetworkReply': + def get(self, _: "QNetworkRequest") -> "QNetworkReply": self.finished.emit(self.reply) return self.reply diff --git a/tests/mock_sounddevice.py b/tests/mock_sounddevice.py index 299cb470..9b56826f 100644 --- a/tests/mock_sounddevice.py +++ b/tests/mock_sounddevice.py @@ -9,34 +9,90 @@ import sounddevice import whisper mock_query_devices = [ - {'name': 'Background Music', 'index': 0, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2, - 'default_low_input_latency': 0.01, - 'default_low_output_latency': 0.008, 'default_high_input_latency': 0.1, 'default_high_output_latency': 0.064, - 'default_samplerate': 8000.0}, - {'name': 'Background Music (UI Sounds)', 'index': 1, 'hostapi': 0, 'max_input_channels': 2, - 'max_output_channels': 2, 'default_low_input_latency': 0.01, - 'default_low_output_latency': 0.008, 'default_high_input_latency': 0.1, 'default_high_output_latency': 0.064, - 'default_samplerate': 8000.0}, - {'name': 'BlackHole 2ch', 'index': 2, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2, - 'default_low_input_latency': 0.01, - 'default_low_output_latency': 0.0013333333333333333, 'default_high_input_latency': 0.1, - 'default_high_output_latency': 0.010666666666666666, 'default_samplerate': 48000.0}, - {'name': 'MacBook Pro Microphone', 'index': 3, 'hostapi': 0, 'max_input_channels': 1, 'max_output_channels': 0, - 'default_low_input_latency': 0.034520833333333334, - 'default_low_output_latency': 0.01, 'default_high_input_latency': 0.043854166666666666, - 'default_high_output_latency': 0.1, 'default_samplerate': 48000.0}, - {'name': 'MacBook Pro Speakers', 'index': 4, 'hostapi': 0, 'max_input_channels': 0, 'max_output_channels': 2, - 'default_low_input_latency': 0.01, - 'default_low_output_latency': 0.0070416666666666666, 'default_high_input_latency': 0.1, - 'default_high_output_latency': 0.016375, 'default_samplerate': 48000.0}, - {'name': 'Null Audio Device', 'index': 5, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2, - 'default_low_input_latency': 0.01, - 'default_low_output_latency': 0.0014512471655328798, 'default_high_input_latency': 0.1, - 'default_high_output_latency': 0.011609977324263039, 'default_samplerate': 44100.0}, - {'name': 'Multi-Output Device', 'index': 6, 'hostapi': 0, 'max_input_channels': 0, 'max_output_channels': 2, - 'default_low_input_latency': 0.01, - 'default_low_output_latency': 0.0033333333333333335, 'default_high_input_latency': 0.1, - 'default_high_output_latency': 0.012666666666666666, 'default_samplerate': 48000.0}, + { + "name": "Background Music", + "index": 0, + "hostapi": 0, + "max_input_channels": 2, + "max_output_channels": 2, + "default_low_input_latency": 0.01, + "default_low_output_latency": 0.008, + "default_high_input_latency": 0.1, + "default_high_output_latency": 0.064, + "default_samplerate": 8000.0, + }, + { + "name": "Background Music (UI Sounds)", + "index": 1, + "hostapi": 0, + "max_input_channels": 2, + "max_output_channels": 2, + "default_low_input_latency": 0.01, + "default_low_output_latency": 0.008, + "default_high_input_latency": 0.1, + "default_high_output_latency": 0.064, + "default_samplerate": 8000.0, + }, + { + "name": "BlackHole 2ch", + "index": 2, + "hostapi": 0, + "max_input_channels": 2, + "max_output_channels": 2, + "default_low_input_latency": 0.01, + "default_low_output_latency": 0.0013333333333333333, + "default_high_input_latency": 0.1, + "default_high_output_latency": 0.010666666666666666, + "default_samplerate": 48000.0, + }, + { + "name": "MacBook Pro Microphone", + "index": 3, + "hostapi": 0, + "max_input_channels": 1, + "max_output_channels": 0, + "default_low_input_latency": 0.034520833333333334, + "default_low_output_latency": 0.01, + "default_high_input_latency": 0.043854166666666666, + "default_high_output_latency": 0.1, + "default_samplerate": 48000.0, + }, + { + "name": "MacBook Pro Speakers", + "index": 4, + "hostapi": 0, + "max_input_channels": 0, + "max_output_channels": 2, + "default_low_input_latency": 0.01, + "default_low_output_latency": 0.0070416666666666666, + "default_high_input_latency": 0.1, + "default_high_output_latency": 0.016375, + "default_samplerate": 48000.0, + }, + { + "name": "Null Audio Device", + "index": 5, + "hostapi": 0, + "max_input_channels": 2, + "max_output_channels": 2, + "default_low_input_latency": 0.01, + "default_low_output_latency": 0.0014512471655328798, + "default_high_input_latency": 0.1, + "default_high_output_latency": 0.011609977324263039, + "default_samplerate": 44100.0, + }, + { + "name": "Multi-Output Device", + "index": 6, + "hostapi": 0, + "max_input_channels": 0, + "max_output_channels": 2, + "default_low_input_latency": 0.01, + "default_low_output_latency": 0.0033333333333333335, + "default_high_input_latency": 0.1, + "default_high_output_latency": 0.012666666666666666, + "default_samplerate": 48000.0, + }, ] @@ -44,7 +100,12 @@ class MockInputStream(MagicMock): running = False thread: Thread - def __init__(self, callback: Callable[[np.ndarray, int, Any, sounddevice.CallbackFlags], None], *args, **kwargs): + def __init__( + self, + callback: Callable[[np.ndarray, int, Any, sounddevice.CallbackFlags], None], + *args, + **kwargs + ): super().__init__(spec=sounddevice.InputStream) self.thread = Thread(target=self.target) self.callback = callback @@ -54,7 +115,9 @@ class MockInputStream(MagicMock): def target(self): sample_rate = whisper.audio.SAMPLE_RATE - file_path = os.path.join(os.path.dirname(__file__), '../testdata/whisper-french.mp3') + file_path = os.path.join( + os.path.dirname(__file__), "../testdata/whisper-french.mp3" + ) audio = whisper.load_audio(file_path, sr=sample_rate) chunk_duration_secs = 1 @@ -65,7 +128,7 @@ class MockInputStream(MagicMock): while self.running: time.sleep(chunk_duration_secs) - chunk = audio[seek:seek + num_samples_in_chunk] + chunk = audio[seek : seek + num_samples_in_chunk] self.callback(chunk, 0, None, sounddevice.CallbackFlags()) seek += num_samples_in_chunk diff --git a/tests/model_loader.py b/tests/model_loader.py index d2558508..7a6599a1 100644 --- a/tests/model_loader.py +++ b/tests/model_loader.py @@ -1,14 +1,13 @@ from buzz.model_loader import TranscriptionModel, ModelDownloader - def get_model_path(transcription_model: TranscriptionModel) -> str: path = transcription_model.get_local_model_path() if path is not None: return path model_loader = ModelDownloader(model=transcription_model) - model_path = '' + model_path = "" def on_load_model(path: str): nonlocal model_path diff --git a/tests/transcriber_benchmarks_test.py b/tests/transcriber_benchmarks_test.py index ad7da1da..f58ac6d9 100644 --- a/tests/transcriber_benchmarks_test.py +++ b/tests/transcriber_benchmarks_test.py @@ -4,20 +4,32 @@ from unittest.mock import Mock import pytest from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel -from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, Task, WhisperCppFileTranscriber, - TranscriptionOptions, WhisperFileTranscriber, FileTranscriber) +from buzz.transcriber import ( + FileTranscriptionOptions, + FileTranscriptionTask, + Task, + WhisperCppFileTranscriber, + TranscriptionOptions, + WhisperFileTranscriber, + FileTranscriber, +) from tests.model_loader import get_model_path def get_task(model: TranscriptionModel): file_transcription_options = FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']) - transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE, - word_level_timings=False, - model=model) + file_paths=["testdata/whisper-french.mp3"] + ) + transcription_options = TranscriptionOptions( + language="fr", task=Task.TRANSCRIBE, word_level_timings=False, model=model + ) model_path = get_model_path(transcription_options.model) - return FileTranscriptionTask(file_path='testdata/audio-long.mp3', transcription_options=transcription_options, - file_transcription_options=file_transcription_options, model_path=model_path) + return FileTranscriptionTask( + file_path="testdata/audio-long.mp3", + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + model_path=model_path, + ) def transcribe(qtbot, transcriber: FileTranscriber): @@ -31,24 +43,53 @@ def transcribe(qtbot, transcriber: FileTranscriber): @pytest.mark.parametrize( - 'transcriber', + "transcriber", [ pytest.param( - WhisperCppFileTranscriber(task=(get_task( - TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)))), - id="Whisper.cpp - Tiny"), - pytest.param( - WhisperFileTranscriber(task=(get_task( - TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)))), - id="Whisper - Tiny"), - pytest.param( - WhisperFileTranscriber(task=(get_task( - TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY)))), - id="Faster Whisper - Tiny", - marks=pytest.mark.skipif(platform.system() == 'Darwin', - reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087') + WhisperCppFileTranscriber( + task=( + get_task( + TranscriptionModel( + model_type=ModelType.WHISPER_CPP, + whisper_model_size=WhisperModelSize.TINY, + ) + ) + ) + ), + id="Whisper.cpp - Tiny", ), - ]) + pytest.param( + WhisperFileTranscriber( + task=( + get_task( + TranscriptionModel( + model_type=ModelType.WHISPER, + whisper_model_size=WhisperModelSize.TINY, + ) + ) + ) + ), + id="Whisper - Tiny", + ), + pytest.param( + WhisperFileTranscriber( + task=( + get_task( + TranscriptionModel( + model_type=ModelType.FASTER_WHISPER, + whisper_model_size=WhisperModelSize.TINY, + ) + ) + ) + ), + id="Faster Whisper - Tiny", + marks=pytest.mark.skipif( + platform.system() == "Darwin", + reason="Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087", + ), + ), + ], +) def test_should_transcribe_and_benchmark(qtbot, benchmark, transcriber): segments = benchmark(transcribe, qtbot, transcriber) assert len(segments) > 0 diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index 4ed8b11d..a0e98636 100644 --- a/tests/transcriber_test.py +++ b/tests/transcriber_test.py @@ -11,26 +11,42 @@ from PyQt6.QtCore import QThread from pytestqt.qtbot import QtBot from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel -from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, Segment, Task, WhisperCpp, WhisperCppFileTranscriber, - WhisperFileTranscriber, - get_default_output_file_path, to_timestamp, - whisper_cpp_params, write_output, TranscriptionOptions) +from buzz.transcriber import ( + FileTranscriptionOptions, + FileTranscriptionTask, + OutputFormat, + Segment, + Task, + WhisperCpp, + WhisperCppFileTranscriber, + WhisperFileTranscriber, + get_default_output_file_path, + to_timestamp, + whisper_cpp_params, + write_output, + TranscriptionOptions, +) from buzz.recording_transcriber import RecordingTranscriber from tests.mock_sounddevice import MockInputStream from tests.model_loader import get_model_path class TestRecordingTranscriber: - @pytest.mark.skip(reason='Hanging') + @pytest.mark.skip(reason="Hanging") def test_should_transcribe(self, qtbot): thread = QThread() - transcription_model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, - whisper_model_size=WhisperModelSize.TINY) + transcription_model = TranscriptionModel( + model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY + ) - transcriber = RecordingTranscriber(transcription_options=TranscriptionOptions( - model=transcription_model, language='fr', task=Task.TRANSCRIBE), - input_device_index=0, sample_rate=16_000) + transcriber = RecordingTranscriber( + transcription_options=TranscriptionOptions( + model=transcription_model, language="fr", task=Task.TRANSCRIBE + ), + input_device_index=0, + sample_rate=16_000, + ) transcriber.moveToThread(thread) thread.finished.connect(thread.deleteLater) @@ -41,39 +57,55 @@ class TestRecordingTranscriber: transcriber.finished.connect(thread.quit) transcriber.finished.connect(transcriber.deleteLater) - with patch('sounddevice.InputStream', side_effect=MockInputStream), patch( - 'sounddevice.check_input_settings'), qtbot.wait_signal(transcriber.transcription, timeout=60 * 1000): + with patch("sounddevice.InputStream", side_effect=MockInputStream), patch( + "sounddevice.check_input_settings" + ), qtbot.wait_signal(transcriber.transcription, timeout=60 * 1000): thread.start() with qtbot.wait_signal(thread.finished, timeout=60 * 1000): transcriber.stop_recording() text = mock_transcription.call_args[0][0] - assert 'Bienvenue dans Passe' in text + assert "Bienvenue dans Passe" in text class TestWhisperCppFileTranscriber: @pytest.mark.parametrize( - 'word_level_timings,expected_segments', + "word_level_timings,expected_segments", [ - (False, [Segment(0, 6560, - 'Bienvenue dans Passe-Relle. Un podcast pensé pour')]), - (True, [Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')]) - ]) - def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]): + ( + False, + [Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")], + ), + (True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]), + ], + ) + def test_transcribe( + self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment] + ): file_transcription_options = FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']) - transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE, - word_level_timings=word_level_timings, - model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, - whisper_model_size=WhisperModelSize.TINY)) + file_paths=["testdata/whisper-french.mp3"] + ) + transcription_options = TranscriptionOptions( + language="fr", + task=Task.TRANSCRIBE, + word_level_timings=word_level_timings, + model=TranscriptionModel( + model_type=ModelType.WHISPER_CPP, + whisper_model_size=WhisperModelSize.TINY, + ), + ) model_path = get_model_path(transcription_options.model) transcriber = WhisperCppFileTranscriber( - task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3', - transcription_options=transcription_options, - file_transcription_options=file_transcription_options, model_path=model_path)) - mock_progress = Mock(side_effect=lambda value: print('progress: ', value)) + task=FileTranscriptionTask( + file_path="testdata/whisper-french.mp3", + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + model_path=model_path, + ) + ) + mock_progress = Mock(side_effect=lambda value: print("progress: ", value)) mock_completed = Mock() transcriber.progress.connect(mock_progress) transcriber.completed.connect(mock_completed) @@ -81,7 +113,11 @@ class TestWhisperCppFileTranscriber: transcriber.run() mock_progress.assert_called() - segments = [segment for segment in mock_completed.call_args[0][0] if len(segment.text) > 0] + segments = [ + segment + for segment in mock_completed.call_args[0][0] + if len(segment.text) > 0 + ] for i, expected_segment in enumerate(expected_segments): assert expected_segment.start == segments[i].start assert expected_segment.end == segments[i].end @@ -90,82 +126,164 @@ class TestWhisperCppFileTranscriber: class TestWhisperFileTranscriber: @pytest.mark.parametrize( - 'output_format,expected_file_path,default_output_file_name', + "output_format,expected_file_path,default_output_file_name", [ - (OutputFormat.SRT, '/a/b/c-translate--Whisper-tiny.srt', '{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}'), - ]) - def test_default_output_file2(self, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str): + ( + OutputFormat.SRT, + "/a/b/c-translate--Whisper-tiny.srt", + "{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}", + ), + ], + ) + def test_default_output_file2( + self, + output_format: OutputFormat, + expected_file_path: str, + default_output_file_name: str, + ): file_path = get_default_output_file_path( task=FileTranscriptionTask( - file_path='/a/b/c.mp4', + file_path="/a/b/c.mp4", transcription_options=TranscriptionOptions(task=Task.TRANSLATE), - file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name=default_output_file_name), - model_path=''), - output_format=output_format) + file_transcription_options=FileTranscriptionOptions( + file_paths=[], default_output_file_name=default_output_file_name + ), + model_path="", + ), + output_format=output_format, + ) assert file_path == expected_file_path def test_default_output_file(self): srt = get_default_output_file_path( task=FileTranscriptionTask( - file_path='/a/b/c.mp4', + file_path="/a/b/c.mp4", transcription_options=TranscriptionOptions(task=Task.TRANSLATE), - file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'), - model_path=''), - output_format=OutputFormat.TXT) - assert srt.startswith('/a/b/c (Translated on ') - assert srt.endswith('.txt') + file_transcription_options=FileTranscriptionOptions( + file_paths=[], + default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})", + ), + model_path="", + ), + output_format=OutputFormat.TXT, + ) + assert srt.startswith("/a/b/c (Translated on ") + assert srt.endswith(".txt") srt = get_default_output_file_path( task=FileTranscriptionTask( - file_path='/a/b/c.mp4', + file_path="/a/b/c.mp4", transcription_options=TranscriptionOptions(task=Task.TRANSLATE), - file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'), - model_path=''), - output_format=OutputFormat.SRT) - assert srt.startswith('/a/b/c (Translated on ') - assert srt.endswith('.srt') + file_transcription_options=FileTranscriptionOptions( + file_paths=[], + default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})", + ), + model_path="", + ), + output_format=OutputFormat.SRT, + ) + assert srt.startswith("/a/b/c (Translated on ") + assert srt.endswith(".srt") @pytest.mark.parametrize( - 'word_level_timings,expected_segments,model,check_progress', + "word_level_timings,expected_segments,model,check_progress", [ - (False, [Segment(0, 6560, - ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')], - TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True), - (True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')], - TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True), - (False, [Segment(0, 8517, - ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances ' - 'et des apprenances de français.')], - TranscriptionModel(model_type=ModelType.HUGGING_FACE, - hugging_face_model_id='openai/whisper-tiny'), False), + ( + False, + [ + Segment( + 0, + 6560, + " Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances", + ) + ], + TranscriptionModel( + model_type=ModelType.WHISPER, + whisper_model_size=WhisperModelSize.TINY, + ), + True, + ), + ( + True, + [Segment(40, 299, " Bien"), Segment(299, 329, "venue dans")], + TranscriptionModel( + model_type=ModelType.WHISPER, + whisper_model_size=WhisperModelSize.TINY, + ), + True, + ), + ( + False, + [ + Segment( + 0, + 8517, + " Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances " + "et des apprenances de français.", + ) + ], + TranscriptionModel( + model_type=ModelType.HUGGING_FACE, + hugging_face_model_id="openai/whisper-tiny", + ), + False, + ), pytest.param( - False, [Segment(start=0, end=8400, - text=' Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.')], - TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY), True, - marks=pytest.mark.skipif(platform.system() == 'Darwin', - reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087') - ) - ]) - def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment], - model: TranscriptionModel, check_progress): + False, + [ + Segment( + start=0, + end=8400, + text=" Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.", + ) + ], + TranscriptionModel( + model_type=ModelType.FASTER_WHISPER, + whisper_model_size=WhisperModelSize.TINY, + ), + True, + marks=pytest.mark.skipif( + platform.system() == "Darwin", + reason="Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087", + ), + ), + ], + ) + def test_transcribe( + self, + qtbot: QtBot, + word_level_timings: bool, + expected_segments: List[Segment], + model: TranscriptionModel, + check_progress, + ): mock_progress = Mock() mock_completed = Mock() - transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE, - word_level_timings=word_level_timings, - model=model) + transcription_options = TranscriptionOptions( + language="fr", + task=Task.TRANSCRIBE, + word_level_timings=word_level_timings, + model=model, + ) model_path = get_model_path(transcription_options.model) - file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3')) - file_transcription_options = FileTranscriptionOptions( - file_paths=[file_path]) + file_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "testdata/whisper-french.mp3") + ) + file_transcription_options = FileTranscriptionOptions(file_paths=[file_path]) transcriber = WhisperFileTranscriber( - task=FileTranscriptionTask(transcription_options=transcription_options, - file_transcription_options=file_transcription_options, - file_path=file_path, model_path=model_path)) + task=FileTranscriptionTask( + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + file_path=file_path, + model_path=model_path, + ) + ) transcriber.progress.connect(mock_progress) transcriber.completed.connect(mock_completed) - with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(transcriber.completed, - timeout=10 * 6000): + with qtbot.wait_signal( + transcriber.progress, timeout=10 * 6000 + ), qtbot.wait_signal(transcriber.completed, timeout=10 * 6000): transcriber.run() # Skip checking progress... @@ -182,26 +300,37 @@ class TestWhisperFileTranscriber: mock_completed.assert_called() segments = mock_completed.call_args[0][0] assert len(segments) >= len(expected_segments) - for (i, expected_segment) in enumerate(expected_segments): + for i, expected_segment in enumerate(expected_segments): assert segments[i] == expected_segment @pytest.mark.skip() def test_transcribe_stop(self): - output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt') + output_file_path = os.path.join(tempfile.gettempdir(), "whisper.txt") if os.path.exists(output_file_path): os.remove(output_file_path) file_transcription_options = FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']) + file_paths=["testdata/whisper-french.mp3"] + ) transcription_options = TranscriptionOptions( - language='fr', task=Task.TRANSCRIBE, word_level_timings=False, - model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)) + language="fr", + task=Task.TRANSCRIBE, + word_level_timings=False, + model=TranscriptionModel( + model_type=ModelType.WHISPER_CPP, + whisper_model_size=WhisperModelSize.TINY, + ), + ) model_path = get_model_path(transcription_options.model) transcriber = WhisperFileTranscriber( - task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options, - file_transcription_options=file_transcription_options, - file_path='testdata/whisper-french.mp3')) + task=FileTranscriptionTask( + model_path=model_path, + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + file_path="testdata/whisper-french.mp3", + ) + ) transcriber.run() time.sleep(1) transcriber.stop() @@ -212,40 +341,54 @@ class TestWhisperFileTranscriber: class TestToTimestamp: def test_to_timestamp(self): - assert to_timestamp(0) == '00:00:00.000' - assert to_timestamp(123456789) == '34:17:36.789' + assert to_timestamp(0) == "00:00:00.000" + assert to_timestamp(123456789) == "34:17:36.789" class TestWhisperCpp: def test_transcribe(self): transcription_options = TranscriptionOptions( - model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)) + model=TranscriptionModel( + model_type=ModelType.WHISPER_CPP, + whisper_model_size=WhisperModelSize.TINY, + ) + ) model_path = get_model_path(transcription_options.model) whisper_cpp = WhisperCpp(model=model_path) params = whisper_cpp_params( - language='fr', task=Task.TRANSCRIBE, word_level_timings=False) + language="fr", task=Task.TRANSCRIBE, word_level_timings=False + ) result = whisper_cpp.transcribe( - audio='testdata/whisper-french.mp3', params=params) + audio="testdata/whisper-french.mp3", params=params + ) - assert 'Bienvenue dans Passe' in result['text'] + assert "Bienvenue dans Passe" in result["text"] @pytest.mark.parametrize( - 'output_format,output_text', + "output_format,output_text", [ - (OutputFormat.TXT, 'Bien\nvenue dans\n'), + (OutputFormat.TXT, "Bien\nvenue dans\n"), ( - OutputFormat.SRT, - '1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n'), - (OutputFormat.VTT, - 'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'), - ]) -def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str): - output_file_path = tmp_path / 'whisper.txt' - segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')] + OutputFormat.SRT, + "1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n", + ), + ( + OutputFormat.VTT, + "WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n", + ), + ], +) +def test_write_output( + tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str +): + output_file_path = tmp_path / "whisper.txt" + segments = [Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")] - write_output(path=str(output_file_path), segments=segments, output_format=output_format) + write_output( + path=str(output_file_path), segments=segments, output_format=output_format + ) - output_file = open(output_file_path, 'r', encoding='utf-8') + output_file = open(output_file_path, "r", encoding="utf-8") assert output_text == output_file.read() diff --git a/tests/transformers_whisper_test.py b/tests/transformers_whisper_test.py index 73471001..c06bd2fe 100644 --- a/tests/transformers_whisper_test.py +++ b/tests/transformers_whisper_test.py @@ -3,8 +3,9 @@ from buzz.transformers_whisper import load_model class TestTransformersWhisper: def test_should_transcribe(self): - model = load_model('openai/whisper-tiny') + model = load_model("openai/whisper-tiny") result = model.transcribe( - audio='testdata/whisper-french.mp3', language='fr', task='transcribe') + audio="testdata/whisper-french.mp3", language="fr", task="transcribe" + ) - assert 'Bienvenue dans Passe' in result['text'] + assert "Bienvenue dans Passe" in result["text"] diff --git a/tests/widgets/file_transcriber_widget_test.py b/tests/widgets/file_transcriber_widget_test.py index 431b4321..432c533b 100644 --- a/tests/widgets/file_transcriber_widget_test.py +++ b/tests/widgets/file_transcriber_widget_test.py @@ -9,13 +9,19 @@ from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidg class TestFileTranscriberWidget: def test_should_set_window_title(self, qtbot: QtBot): widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None) + file_paths=["testdata/whisper-french.mp3"], + default_output_file_name="", + parent=None, + ) qtbot.add_widget(widget) - assert widget.windowTitle() == 'whisper-french.mp3' + assert widget.windowTitle() == "whisper-french.mp3" def test_should_emit_triggered_event(self, qtbot: QtBot): widget = FileTranscriberWidget( - file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None) + file_paths=["testdata/whisper-french.mp3"], + default_output_file_name="", + parent=None, + ) qtbot.add_widget(widget) mock_triggered = Mock() @@ -24,9 +30,11 @@ class TestFileTranscriberWidget: with qtbot.wait_signal(widget.triggered, timeout=30 * 1000): qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton) - transcription_options, file_transcription_options, model_path = mock_triggered.call_args[ - 0][0] + ( + transcription_options, + file_transcription_options, + model_path, + ) = mock_triggered.call_args[0][0] assert transcription_options.language is None - assert file_transcription_options.file_paths == [ - 'testdata/whisper-french.mp3'] + assert file_transcription_options.file_paths == ["testdata/whisper-french.mp3"] assert len(model_path) > 0 diff --git a/tests/widgets/model_download_progress_dialog.py b/tests/widgets/model_download_progress_dialog.py index 0248ea7c..3a47331a 100644 --- a/tests/widgets/model_download_progress_dialog.py +++ b/tests/widgets/model_download_progress_dialog.py @@ -8,7 +8,7 @@ class TestModelDownloadProgressDialog: def test_should_show_dialog(self, qtbot): dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None) qtbot.add_widget(dialog) - assert dialog.labelText() == 'Downloading model (0%)' + assert dialog.labelText() == "Downloading model (0%)" def test_should_update_label_on_progress(self, qtbot): dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None) @@ -16,12 +16,10 @@ class TestModelDownloadProgressDialog: dialog.set_value(0.0) dialog.set_value(0.01) - assert dialog.labelText().startswith( - 'Downloading model (1%') + assert dialog.labelText().startswith("Downloading model (1%") dialog.set_value(0.1) - assert dialog.labelText().startswith( - 'Downloading model (10%') + assert dialog.labelText().startswith("Downloading model (10%") # Other windows should not be processing while models are being downloaded def test_should_be_an_application_modal(self, qtbot): diff --git a/tests/widgets/model_type_combo_box_test.py b/tests/widgets/model_type_combo_box_test.py index 43865b0b..809384c0 100644 --- a/tests/widgets/model_type_combo_box_test.py +++ b/tests/widgets/model_type_combo_box_test.py @@ -7,8 +7,8 @@ class TestModelTypeComboBox: qtbot.add_widget(widget) assert widget.count() == 5 - assert widget.itemText(0) == 'Whisper' - assert widget.itemText(1) == 'Whisper.cpp' - assert widget.itemText(2) == 'Hugging Face' - assert widget.itemText(3) == 'Faster Whisper' - assert widget.itemText(4) == 'OpenAI Whisper API' + assert widget.itemText(0) == "Whisper" + assert widget.itemText(1) == "Whisper.cpp" + assert widget.itemText(2) == "Hugging Face" + assert widget.itemText(3) == "Faster Whisper" + assert widget.itemText(4) == "OpenAI Whisper API" diff --git a/tests/widgets/openai_api_key_line_edit_test.py b/tests/widgets/openai_api_key_line_edit_test.py index 3d2b0a88..6383763b 100644 --- a/tests/widgets/openai_api_key_line_edit_test.py +++ b/tests/widgets/openai_api_key_line_edit_test.py @@ -3,14 +3,14 @@ from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit class TestOpenAIAPIKeyLineEdit: def test_should_emit_key_changed(self, qtbot): - line_edit = OpenAIAPIKeyLineEdit(key='') + line_edit = OpenAIAPIKeyLineEdit(key="") qtbot.add_widget(line_edit) with qtbot.wait_signal(line_edit.key_changed): - line_edit.setText('abcdefg') + line_edit.setText("abcdefg") def test_should_toggle_visibility(self, qtbot): - line_edit = OpenAIAPIKeyLineEdit(key='') + line_edit = OpenAIAPIKeyLineEdit(key="") qtbot.add_widget(line_edit) assert line_edit.echoMode() == OpenAIAPIKeyLineEdit.EchoMode.Password diff --git a/tests/widgets/preferences_dialog/general_preferences_widget_test.py b/tests/widgets/preferences_dialog/general_preferences_widget_test.py index 5261fa2f..cc0402d0 100644 --- a/tests/widgets/preferences_dialog/general_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/general_preferences_widget_test.py @@ -4,32 +4,35 @@ import pytest from PyQt6.QtWidgets import QPushButton, QMessageBox, QLineEdit from buzz.store.keyring_store import KeyringStore -from buzz.widgets.preferences_dialog.general_preferences_widget import \ - GeneralPreferencesWidget +from buzz.widgets.preferences_dialog.general_preferences_widget import ( + GeneralPreferencesWidget, +) class TestGeneralPreferencesWidget: def test_should_disable_test_button_if_no_api_key(self, qtbot): - widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''), - default_export_file_name='') + widget = GeneralPreferencesWidget( + keyring_store=self.get_keyring_store(""), default_export_file_name="" + ) qtbot.add_widget(widget) test_button = widget.findChild(QPushButton) assert isinstance(test_button, QPushButton) - assert test_button.text() == 'Test' + assert test_button.text() == "Test" assert not test_button.isEnabled() line_edit = widget.findChild(QLineEdit) assert isinstance(line_edit, QLineEdit) - line_edit.setText('123') + line_edit.setText("123") assert test_button.isEnabled() def test_should_test_openai_api_key(self, qtbot): widget = GeneralPreferencesWidget( - keyring_store=self.get_keyring_store('wrong-api-key'), - default_export_file_name='') + keyring_store=self.get_keyring_store("wrong-api-key"), + default_export_file_name="", + ) qtbot.add_widget(widget) test_button = widget.findChild(QPushButton) @@ -42,9 +45,11 @@ class TestGeneralPreferencesWidget: def mock_called(): mock.assert_called() - assert mock.call_args[0][1] == 'OpenAI API Key Test' - assert mock.call_args[0][ - 2] == 'Incorrect API key provided: wrong-ap*-key. You can find your API key at https://platform.openai.com/account/api-keys.' + assert mock.call_args[0][1] == "OpenAI API Key Test" + assert ( + mock.call_args[0][2] + == "Incorrect API key provided: wrong-ap*-key. You can find your API key at https://platform.openai.com/account/api-keys." + ) qtbot.waitUntil(mock_called) diff --git a/tests/widgets/preferences_dialog/models_preferences_widget_test.py b/tests/widgets/preferences_dialog/models_preferences_widget_test.py index ea668986..30940239 100644 --- a/tests/widgets/preferences_dialog/models_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/models_preferences_widget_test.py @@ -5,16 +5,20 @@ from PyQt6.QtCore import Qt from PyQt6.QtWidgets import QComboBox, QPushButton from pytestqt.qtbot import QtBot -from buzz.model_loader import get_whisper_file_path, WhisperModelSize, \ - TranscriptionModel, \ - ModelType -from buzz.widgets.preferences_dialog.models_preferences_widget import \ - ModelsPreferencesWidget +from buzz.model_loader import ( + get_whisper_file_path, + WhisperModelSize, + TranscriptionModel, + ModelType, +) +from buzz.widgets.preferences_dialog.models_preferences_widget import ( + ModelsPreferencesWidget, +) from tests.model_loader import get_model_path class TestModelsPreferencesWidget: - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def clear_model_cache(self): file_path = get_whisper_file_path(size=WhisperModelSize.TINY) if os.path.isfile(file_path): @@ -25,10 +29,10 @@ class TestModelsPreferencesWidget: qtbot.add_widget(widget) first_item = widget.model_list_widget.topLevelItem(0) - assert first_item.text(0) == 'Downloaded' + assert first_item.text(0) == "Downloaded" second_item = widget.model_list_widget.topLevelItem(1) - assert second_item.text(0) == 'Available for Download' + assert second_item.text(0) == "Available for Download" def test_should_change_model_type(self, qtbot): widget = ModelsPreferencesWidget() @@ -36,36 +40,38 @@ class TestModelsPreferencesWidget: combo_box = widget.findChild(QComboBox) assert isinstance(combo_box, QComboBox) - combo_box.setCurrentText('Faster Whisper') + combo_box.setCurrentText("Faster Whisper") first_item = widget.model_list_widget.topLevelItem(0) - assert first_item.text(0) == 'Downloaded' + assert first_item.text(0) == "Downloaded" second_item = widget.model_list_widget.topLevelItem(1) - assert second_item.text(0) == 'Available for Download' + assert second_item.text(0) == "Available for Download" def test_should_download_model(self, qtbot: QtBot, clear_model_cache): # make progress dialog non-modal to unblock qtbot.wait_until widget = ModelsPreferencesWidget( - progress_dialog_modality=Qt.WindowModality.NonModal) + progress_dialog_modality=Qt.WindowModality.NonModal + ) qtbot.add_widget(widget) - model = TranscriptionModel(model_type=ModelType.WHISPER, - whisper_model_size=WhisperModelSize.TINY) + model = TranscriptionModel( + model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY + ) assert model.get_local_model_path() is None available_item = widget.model_list_widget.topLevelItem(1) - assert available_item.text(0) == 'Available for Download' + assert available_item.text(0) == "Available for Download" tiny_item = available_item.child(0) - assert tiny_item.text(0) == 'Tiny' + assert tiny_item.text(0) == "Tiny" tiny_item.setSelected(True) - download_button = widget.findChild(QPushButton, 'DownloadButton') + download_button = widget.findChild(QPushButton, "DownloadButton") assert isinstance(download_button, QPushButton) - assert download_button.text() == 'Download' + assert download_button.text() == "Download" download_button.click() def downloaded_model(): @@ -73,22 +79,26 @@ class TestModelsPreferencesWidget: _downloaded_item = widget.model_list_widget.topLevelItem(0) assert _downloaded_item.childCount() > 0 - assert _downloaded_item.child(0).text(0) == 'Tiny' + assert _downloaded_item.child(0).text(0) == "Tiny" _available_item = widget.model_list_widget.topLevelItem(1) - assert _available_item.childCount() == 0 or _available_item.child(0).text( - 0) != 'Tiny' + assert ( + _available_item.childCount() == 0 + or _available_item.child(0).text(0) != "Tiny" + ) # model file exists assert os.path.isfile(get_whisper_file_path(size=model.whisper_model_size)) qtbot.wait_until(callback=downloaded_model, timeout=60_000) - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def whisper_tiny_model_path(self) -> str: - return get_model_path(transcription_model=TranscriptionModel( - model_type=ModelType.WHISPER, - whisper_model_size=WhisperModelSize.TINY)) + return get_model_path( + transcription_model=TranscriptionModel( + model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY + ) + ) def test_should_show_downloaded_model(self, qtbot, whisper_tiny_model_path): widget = ModelsPreferencesWidget() @@ -96,15 +106,16 @@ class TestModelsPreferencesWidget: qtbot.add_widget(widget) available_item = widget.model_list_widget.topLevelItem(0) - assert available_item.text(0) == 'Downloaded' + assert available_item.text(0) == "Downloaded" tiny_item = available_item.child(0) - assert tiny_item.text(0) == 'Tiny' + assert tiny_item.text(0) == "Tiny" tiny_item.setSelected(True) - delete_button = widget.findChild(QPushButton, 'DeleteButton') + delete_button = widget.findChild(QPushButton, "DeleteButton") assert delete_button.isVisible() - show_file_location_button = widget.findChild(QPushButton, - 'ShowFileLocationButton') + show_file_location_button = widget.findChild( + QPushButton, "ShowFileLocationButton" + ) assert show_file_location_button.isVisible() diff --git a/tests/widgets/preferences_dialog/preferences_dialog_test.py b/tests/widgets/preferences_dialog/preferences_dialog_test.py index 0dacb484..7c15402c 100644 --- a/tests/widgets/preferences_dialog/preferences_dialog_test.py +++ b/tests/widgets/preferences_dialog/preferences_dialog_test.py @@ -6,14 +6,14 @@ from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog class TestPreferencesDialog: def test_create(self, qtbot: QtBot): - dialog = PreferencesDialog(shortcuts={}, default_export_file_name='') + dialog = PreferencesDialog(shortcuts={}, default_export_file_name="") qtbot.add_widget(dialog) - assert dialog.windowTitle() == 'Preferences' + assert dialog.windowTitle() == "Preferences" tab_widget = dialog.findChild(QTabWidget) assert isinstance(tab_widget, QTabWidget) assert tab_widget.count() == 3 - assert tab_widget.tabText(0) == 'General' - assert tab_widget.tabText(1) == 'Models' - assert tab_widget.tabText(2) == 'Shortcuts' + assert tab_widget.tabText(0) == "General" + assert tab_widget.tabText(1) == "Models" + assert tab_widget.tabText(2) == "Shortcuts" diff --git a/tests/widgets/shortcuts_editor_widget_test.py b/tests/widgets/shortcuts_editor_widget_test.py index ea23e44d..ba833fcf 100644 --- a/tests/widgets/shortcuts_editor_widget_test.py +++ b/tests/widgets/shortcuts_editor_widget_test.py @@ -1,14 +1,17 @@ from PyQt6.QtWidgets import QPushButton, QLabel from buzz.settings.shortcut import Shortcut -from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import \ - ShortcutsEditorPreferencesWidget +from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import ( + ShortcutsEditorPreferencesWidget, +) from buzz.widgets.sequence_edit import SequenceEdit class TestShortcutsEditorWidget: def test_should_reset_to_defaults(self, qtbot): - widget = ShortcutsEditorPreferencesWidget(shortcuts=Shortcut.get_default_shortcuts()) + widget = ShortcutsEditorPreferencesWidget( + shortcuts=Shortcut.get_default_shortcuts() + ) qtbot.add_widget(widget) reset_button = widget.findChild(QPushButton) @@ -19,12 +22,13 @@ class TestShortcutsEditorWidget: sequence_edits = widget.findChildren(SequenceEdit) expected = ( - ('Open Record Window', 'Ctrl+R'), - ('Import File', 'Ctrl+O'), - ('Open Preferences Window', 'Ctrl+,'), - ('Open Transcript Viewer', 'Ctrl+E'), - ('Clear History', 'Ctrl+S'), - ('Cancel Transcription', 'Ctrl+X')) + ("Open Record Window", "Ctrl+R"), + ("Import File", "Ctrl+O"), + ("Open Preferences Window", "Ctrl+,"), + ("Open Transcript Viewer", "Ctrl+E"), + ("Clear History", "Ctrl+S"), + ("Cancel Transcription", "Ctrl+X"), + ) for i, (label, sequence_edit) in enumerate(zip(labels, sequence_edits)): assert isinstance(label, QLabel) diff --git a/tests/widgets/transcription_tasks_table_widget_test.py b/tests/widgets/transcription_tasks_table_widget_test.py index 11838654..abc449c1 100644 --- a/tests/widgets/transcription_tasks_table_widget_test.py +++ b/tests/widgets/transcription_tasks_table_widget_test.py @@ -2,57 +2,70 @@ import datetime from pytestqt.qtbot import QtBot -from buzz.transcriber import FileTranscriptionTask, TranscriptionOptions, FileTranscriptionOptions +from buzz.transcriber import ( + FileTranscriptionTask, + TranscriptionOptions, + FileTranscriptionOptions, +) from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget class TestTranscriptionTasksTableWidget: - def test_upsert_task(self, qtbot: QtBot): widget = TranscriptionTasksTableWidget() qtbot.add_widget(widget) - task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3', - transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']), model_path='', - status=FileTranscriptionTask.Status.QUEUED) + task = FileTranscriptionTask( + id=0, + file_path="testdata/whisper-french.mp3", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions( + file_paths=["testdata/whisper-french.mp3"] + ), + model_path="", + status=FileTranscriptionTask.Status.QUEUED, + ) task.queued_at = datetime.datetime(2023, 4, 12, 0, 0, 0) task.started_at = datetime.datetime(2023, 4, 12, 0, 0, 5) widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == 'whisper-french.mp3' - assert widget.item(0, 2).text() == 'Queued' + assert widget.item(0, 1).text() == "whisper-french.mp3" + assert widget.item(0, 2).text() == "Queued" task.status = FileTranscriptionTask.Status.IN_PROGRESS task.fraction_completed = 0.3524 widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == 'whisper-french.mp3' - assert widget.item(0, 2).text() == 'In Progress (35%)' + assert widget.item(0, 1).text() == "whisper-french.mp3" + assert widget.item(0, 2).text() == "In Progress (35%)" task.status = FileTranscriptionTask.Status.COMPLETED task.completed_at = datetime.datetime(2023, 4, 12, 0, 0, 10) widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == 'whisper-french.mp3' - assert widget.item(0, 2).text() == 'Completed (5s)' + assert widget.item(0, 1).text() == "whisper-french.mp3" + assert widget.item(0, 2).text() == "Completed (5s)" def test_upsert_task_no_timings(self, qtbot: QtBot): widget = TranscriptionTasksTableWidget() qtbot.add_widget(widget) - task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3', - transcription_options=TranscriptionOptions(), - file_transcription_options=FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']), model_path='', - status=FileTranscriptionTask.Status.COMPLETED) + task = FileTranscriptionTask( + id=0, + file_path="testdata/whisper-french.mp3", + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions( + file_paths=["testdata/whisper-french.mp3"] + ), + model_path="", + status=FileTranscriptionTask.Status.COMPLETED, + ) widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == 'whisper-french.mp3' - assert widget.item(0, 2).text() == 'Completed' + assert widget.item(0, 1).text() == "whisper-french.mp3" + assert widget.item(0, 2).text() == "Completed" diff --git a/tests/widgets/transcription_viewer_test.py b/tests/widgets/transcription_viewer_test.py index 278dee7a..33bb1622 100644 --- a/tests/widgets/transcription_viewer_test.py +++ b/tests/widgets/transcription_viewer_test.py @@ -6,69 +6,85 @@ from PyQt6.QtGui import QKeyEvent from PyQt6.QtWidgets import QPushButton, QToolBar, QToolButton from pytestqt.qtbot import QtBot -from buzz.transcriber import FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, Segment -from buzz.widgets.transcription_segments_editor_widget import TranscriptionSegmentsEditorWidget +from buzz.transcriber import ( + FileTranscriptionTask, + FileTranscriptionOptions, + TranscriptionOptions, + Segment, +) +from buzz.widgets.transcription_segments_editor_widget import ( + TranscriptionSegmentsEditorWidget, +) from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget class TestTranscriptionViewerWidget: @pytest.fixture() def task(self) -> FileTranscriptionTask: - return FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3', - file_transcription_options=FileTranscriptionOptions( - file_paths=['testdata/whisper-french.mp3']), - transcription_options=TranscriptionOptions(), - segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')], - model_path='') + return FileTranscriptionTask( + id=0, + file_path="testdata/whisper-french.mp3", + file_transcription_options=FileTranscriptionOptions( + file_paths=["testdata/whisper-french.mp3"] + ), + transcription_options=TranscriptionOptions(), + segments=[Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")], + model_path="", + ) def test_should_display_segments(self, qtbot: QtBot, task): widget = TranscriptionViewerWidget( - transcription_task=task, open_transcription_output=False) + transcription_task=task, open_transcription_output=False + ) qtbot.add_widget(widget) - assert widget.windowTitle() == 'whisper-french.mp3' + assert widget.windowTitle() == "whisper-french.mp3" editor = widget.findChild(TranscriptionSegmentsEditorWidget) assert isinstance(editor, TranscriptionSegmentsEditorWidget) - assert editor.item(0, 0).text() == '00:00:00.040' - assert editor.item(0, 1).text() == '00:00:00.299' - assert editor.item(0, 2).text() == 'Bien' + assert editor.item(0, 0).text() == "00:00:00.040" + assert editor.item(0, 1).text() == "00:00:00.299" + assert editor.item(0, 2).text() == "Bien" def test_should_update_segment_text(self, qtbot, task): widget = TranscriptionViewerWidget( - transcription_task=task, open_transcription_output=False) + transcription_task=task, open_transcription_output=False + ) qtbot.add_widget(widget) editor = widget.findChild(TranscriptionSegmentsEditorWidget) assert isinstance(editor, TranscriptionSegmentsEditorWidget) # Change text - editor.item(0, 2).setText('Biens') - assert task.segments[0].text == 'Biens' + editor.item(0, 2).setText("Biens") + assert task.segments[0].text == "Biens" # Undo toolbar = widget.findChild(QToolBar) undo_action, redo_action = toolbar.actions() undo_action.trigger() - assert task.segments[0].text == 'Bien' + assert task.segments[0].text == "Bien" redo_action.trigger() - assert task.segments[0].text == 'Biens' + assert task.segments[0].text == "Biens" def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot, task): widget = TranscriptionViewerWidget( - transcription_task=task, open_transcription_output=False) + transcription_task=task, open_transcription_output=False + ) qtbot.add_widget(widget) export_button = widget.findChild(QPushButton) assert isinstance(export_button, QPushButton) - output_file_path = tmp_path / 'whisper.txt' - with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock: - save_file_name_mock.return_value = (str(output_file_path), '') + output_file_path = tmp_path / "whisper.txt" + with patch( + "PyQt6.QtWidgets.QFileDialog.getSaveFileName" + ) as save_file_name_mock: + save_file_name_mock.return_value = (str(output_file_path), "") export_button.menu().actions()[0].trigger() - output_file = open(output_file_path, 'r', encoding='utf-8') - assert 'Bien\nvenue dans' in output_file.read() + output_file = open(output_file_path, "r", encoding="utf-8") + assert "Bien\nvenue dans" in output_file.read()