diff --git a/.coveragerc b/.coveragerc index 49b1269..9843aa0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,4 +7,4 @@ omit = directory = coverage/html [report] -fail_under = 75 +fail_under = 78 diff --git a/Makefile b/Makefile index e7ea18a..06fc516 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ clean: rm -rf dist/* || true test: buzz/whisper_cpp.py - pytest --cov=buzz --cov-report=xml --cov-report=html + pytest -vv --cov=buzz --cov-report=xml --cov-report=html dist/Buzz dist/Buzz.app: buzz/whisper_cpp.py pyinstaller --noconfirm Buzz.spec diff --git a/buzz/cache.py b/buzz/cache.py index cd1d335..e1d4eb6 100644 --- a/buzz/cache.py +++ b/buzz/cache.py @@ -23,7 +23,7 @@ class TasksCache: return pickle.load(file) except FileNotFoundError: return [] - except pickle.UnpicklingError: # delete corrupted cache + except (pickle.UnpicklingError, AttributeError): # delete corrupted cache os.remove(self.file_path) return [] diff --git a/buzz/gui.py b/buzz/gui.py index f463f88..82f0df0 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -1,4 +1,7 @@ +import dataclasses import enum +import json +import logging import os import platform import random @@ -10,27 +13,28 @@ import humanize import sounddevice from PyQt6 import QtGui from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread, - QTimer, QUrl, pyqtSignal, QModelIndex, QSize) + QTimer, QUrl, pyqtSignal, QModelIndex, QSize, QPoint, + QUrlQuery, QMetaObject, QEvent) from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon, - QKeySequence, QPixmap, QTextCursor, QValidator) + QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent) +from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog, QDialogButtonBox, QFileDialog, QLabel, QLineEdit, QMainWindow, QMessageBox, QPlainTextEdit, QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu, QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem, - QHeaderView, QAbstractItemView) + QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem) from requests import get from whisper import tokenizer from buzz.cache import TasksCache - from .__version__ import VERSION -from .model_loader import ModelLoader +from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Task, WhisperCppFileTranscriber, WhisperFileTranscriber, get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions, - Model, FileTranscriberQueueWorker, FileTranscriptionTask) + FileTranscriberQueueWorker, FileTranscriptionTask) APP_NAME = 'Buzz' @@ -99,7 +103,7 @@ class AudioDevicesComboBox(QComboBox): class LanguagesComboBox(QComboBox): """LanguagesComboBox displays a list of languages available to use with Whisper""" - # language is a languge key from whisper.tokenizer.LANGUAGES or '' for "detect language" + # language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language" languageChanged = pyqtSignal(str) def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None: @@ -136,20 +140,6 @@ class TasksComboBox(QComboBox): self.taskChanged.emit(self.tasks[index]) -class ModelComboBox(QComboBox): - model_changed = pyqtSignal(Model) - - def __init__(self, default_model: Model, parent: Optional[QWidget], *args) -> None: - super().__init__(parent, *args) - self.models = [model for 
model in Model]
-        self.addItems([model.value for model in self.models])
-        self.currentIndexChanged.connect(self.on_index_changed)
-        self.setCurrentText(default_model.value)
-
-    def on_index_changed(self, index: int):
-        self.model_changed.emit(self.models[index])
-
-
 class TextDisplayBox(QPlainTextEdit):
     """TextDisplayBox is a read-only textbox"""

@@ -202,6 +192,7 @@ class DownloadModelProgressDialog(QProgressDialog):
                          'Cancel', 0, 100, parent, *args)
         self.setWindowModality(Qt.WindowModality.ApplicationModal)
         self.start_time = datetime.now()
+        self.setFixedSize(self.size())

     def set_fraction_completed(self, fraction_completed: float) -> None:
         self.setValue(int(fraction_completed * self.maximum()))
@@ -280,7 +271,7 @@ class TimerLabel(QLabel):


 def show_model_download_error_dialog(parent: QWidget, error: str):
-    message = f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more ' \
-              f'information. '
+    message = f"An error occurred while loading the Whisper model: {error}{'' if error.endswith('.') else '.'}" \
+              f' Please retry or check the application logs for more information.'
     QMessageBox.critical(parent, '', message)


@@ -302,7 +293,6 @@ class FileTranscriberWidget(QWidget):
         super().__init__(parent, flags)

         self.setWindowTitle(file_paths_as_title(file_paths))
-        self.setFixedSize(420, 270)

         self.file_paths = file_paths
         self.transcription_options = TranscriptionOptions()
@@ -332,15 +322,19 @@ class FileTranscriberWidget(QWidget):
         layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)

         self.setLayout(layout)
+        self.setFixedSize(self.sizeHint())

     def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
         self.transcription_options = transcription_options
+        self.word_level_timings_checkbox.setDisabled(
+            self.transcription_options.model.model_type == ModelType.HUGGING_FACE)

     def on_click_run(self):
         self.run_button.setDisabled(True)

         self.transcriber_thread = QThread()
         self.model_loader = ModelLoader(model=self.transcription_options.model)
+        self.model_loader.moveToThread(self.transcriber_thread)

         self.transcriber_thread.started.connect(self.model_loader.run)
         self.model_loader.finished.connect(
@@ -378,6 +372,7 @@ class FileTranscriberWidget(QWidget):
             self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)

     def on_download_model_error(self, error: str):
+        self.reset_model_download()
         show_model_download_error_dialog(self, error)
         self.reset_transcriber_controls()

@@ -391,6 +386,7 @@ class FileTranscriberWidget(QWidget):

     def reset_model_download(self):
         if self.model_download_progress_dialog is not None:
+            self.model_download_progress_dialog.close()
             self.model_download_progress_dialog = None

     def on_word_level_timings_changed(self, value: int):
@@ -477,9 +473,9 @@ class RecordingTranscriberWidget(QWidget):
         layout = QVBoxLayout(self)

         self.setWindowTitle('Live Recording')
-        self.setFixedSize(400, 520)

-        self.transcription_options = TranscriptionOptions(model=Model.WHISPER_CPP_TINY)
+        self.transcription_options = TranscriptionOptions(model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
+                                                                                   whisper_model_size=WhisperModelSize.TINY))

         self.audio_devices_combo_box = AudioDevicesComboBox(self)
         self.audio_devices_combo_box.device_changed.connect(
@@ -514,6 +510,7 @@
         layout.addWidget(self.text_box)

         self.setLayout(layout)
+        self.setFixedSize(self.sizeHint())

     def closeEvent(self, event: QCloseEvent) -> None:
         self.stop_recording()
@@ -534,8 +531,8 @@

     def start_recording(self):
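+        # Whisper.cpp handles live transcription only when a language is explicitly
+        # selected; with "detect language" (language is None), the Whisper
+        # implementation is used instead.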
self.record_button.setDisabled(True) - use_whisper_cpp = self.transcription_options.model.is_whisper_cpp( - ) and self.transcription_options.language is not None + use_whisper_cpp = self.transcription_options.model.model_type == ModelType.WHISPER_CPP and \ + self.transcription_options.language is not None def start_recording_transcription(model_path: str): # Clear text box placeholder because the first chunk takes a while to process @@ -593,6 +590,7 @@ class RecordingTranscriberWidget(QWidget): self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size) def on_download_model_error(self, error: str): + self.reset_model_download() show_model_download_error_dialog(self, error) self.stop_recording() self.record_button.force_stop() @@ -620,6 +618,7 @@ class RecordingTranscriberWidget(QWidget): def reset_model_download(self): if self.model_download_progress_dialog is not None: + self.model_download_progress_dialog.close() self.model_download_progress_dialog = None @@ -964,6 +963,131 @@ class MainWindow(QMainWindow): super().closeEvent(event) +class LineEdit(QLineEdit): + def __init__(self, default_text: str = '', parent: Optional[QWidget] = None): + super().__init__(default_text, parent) + if platform.system() == 'Darwin': + self.setStyleSheet('QLineEdit { padding: 4px }') + + +# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py +class HuggingFaceSearchLineEdit(LineEdit): + model_selected = pyqtSignal(str) + popup: QListWidget + + def __init__(self, parent: Optional[QWidget] = None): + super().__init__('', parent) + + self.setMinimumWidth(150) + self.setPlaceholderText('openai/whisper-tiny') + + self.timer = QTimer(self) + self.timer.setSingleShot(True) + self.timer.setInterval(250) + self.timer.timeout.connect(self.fetch_models) + + # Restart debounce timer each time editor text changes + self.textEdited.connect(self.timer.start) + self.textEdited.connect(self.on_text_edited) + + self.network_manager = QNetworkAccessManager(self) + self.network_manager.finished.connect(self.on_request_response) + + self.popup = QListWidget() + self.popup.setWindowFlags(Qt.WindowType.Popup) + self.popup.setFocusProxy(self) + self.popup.setMouseTracking(True) + self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) + self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.popup.installEventFilter(self) + self.popup.itemClicked.connect(self.on_select_item) + + def on_text_edited(self, text: str): + self.model_selected.emit(text) + + def on_select_item(self): + self.popup.hide() + self.setFocus() + + item = self.popup.currentItem() + self.setText(item.text()) + QMetaObject.invokeMethod(self, 'returnPressed') + self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole)) + + def fetch_models(self): + text = self.text() + if len(text) < 3: + return + + url = QUrl("https://huggingface.co/api/models") + + query = QUrlQuery() + query.addQueryItem("filter", "whisper") + query.addQueryItem("search", text) + + url.setQuery(query) + + return self.network_manager.get(QNetworkRequest(url)) + + def on_popup_selected(self): + self.timer.stop() + + def on_request_response(self, network_reply: QNetworkReply): + if network_reply.error() != QNetworkReply.NetworkError.NoError: + logging.debug('Error fetching Hugging Face models: %s', network_reply.error()) + return + + models = json.loads(network_reply.readAll().data()) + + self.popup.setUpdatesEnabled(False) + self.popup.clear() + + for model in 
models: + model_id = model.get('id') + + item = QListWidgetItem(self.popup) + item.setText(model_id) + item.setData(Qt.ItemDataRole.UserRole, model_id) + + self.popup.setCurrentItem(self.popup.item(0)) + self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20) + self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll + self.popup.setUpdatesEnabled(True) + self.popup.move(self.mapToGlobal(QPoint(0, self.height()))) + self.popup.setFocus() + self.popup.show() + + def eventFilter(self, target: QObject, event: QEvent): + if hasattr(self, 'popup') is False or target != self.popup: + return False + + if event.type() == QEvent.Type.MouseButtonPress: + self.popup.hide() + self.setFocus() + return True + + if isinstance(event, QKeyEvent): + key = event.key() + if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]: + self.on_select_item() + return True + + if key == Qt.Key.Key_Escape: + self.setFocus() + self.popup.hide() + return True + + if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp, + Qt.Key.Key_PageDown]: + return False + + self.setFocus() + self.event(event) + self.popup.hide() + + return False + + class TranscriptionOptionsGroupBox(QGroupBox): transcription_options: TranscriptionOptions transcription_options_changed = pyqtSignal(TranscriptionOptions) @@ -972,7 +1096,7 @@ class TranscriptionOptionsGroupBox(QGroupBox): super().__init__(title='', parent=parent) self.transcription_options = default_transcription_options - layout = QFormLayout(self) + self.form_layout = QFormLayout(self) self.tasks_combo_box = TasksComboBox( default_task=self.transcription_options.task, @@ -985,30 +1109,41 @@ class TranscriptionOptionsGroupBox(QGroupBox): self.languages_combo_box.languageChanged.connect( self.on_language_changed) - self.model_combo_box = ModelComboBox( - default_model=self.transcription_options.model, - parent=self) - self.model_combo_box.model_changed.connect(self.on_model_changed) - self.advanced_settings_button = AdvancedSettingsButton(self) self.advanced_settings_button.clicked.connect( self.open_advanced_settings) - layout.addRow('Task:', self.tasks_combo_box) - layout.addRow('Language:', self.languages_combo_box) - layout.addRow('Model:', self.model_combo_box) - layout.addRow('', self.advanced_settings_button) + self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit() + self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed) - self.setLayout(layout) + model_type_combo_box = QComboBox(self) + model_type_combo_box.addItems([model_type.value for model_type in ModelType]) + model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value) + model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed) + + self.whisper_model_size_combo_box = QComboBox(self) + self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize]) + if default_transcription_options.model.whisper_model_size is not None: + self.whisper_model_size_combo_box.setCurrentText( + default_transcription_options.model.whisper_model_size.value.title()) + self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed) + + self.form_layout.addRow('Task:', self.tasks_combo_box) + self.form_layout.addRow('Language:', self.languages_combo_box) + self.form_layout.addRow('Model:', model_type_combo_box) + self.form_layout.addRow('', self.whisper_model_size_combo_box) + self.form_layout.addRow('', 
self.hugging_face_search_line_edit)
+
+        self.form_layout.setRowVisible(self.hugging_face_search_line_edit, False)
+
+        self.form_layout.addRow('', self.advanced_settings_button)
+
+        self.setLayout(self.form_layout)

     def on_language_changed(self, language: str):
         self.transcription_options.language = language
         self.transcription_options_changed.emit(self.transcription_options)

-    def on_model_changed(self, model: Model):
-        self.transcription_options.model = model
-        self.transcription_options_changed.emit(self.transcription_options)
-
     def on_task_changed(self, task: Task):
         self.transcription_options.task = task
         self.transcription_options_changed.emit(self.transcription_options)
@@ -1032,6 +1167,23 @@ class TranscriptionOptionsGroupBox(QGroupBox):
         self.transcription_options = transcription_options
         self.transcription_options_changed.emit(transcription_options)

+    def on_model_type_changed(self, text: str):
+        model_type = ModelType(text)
+        self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
+        self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
+                                       (model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP))
+        self.transcription_options.model.model_type = model_type
+        self.transcription_options_changed.emit(self.transcription_options)
+
+    def on_whisper_model_size_changed(self, text: str):
+        model_size = WhisperModelSize(text.lower())
+        self.transcription_options.model.whisper_model_size = model_size
+        self.transcription_options_changed.emit(self.transcription_options)
+
+    def on_hugging_face_model_changed(self, model: str):
+        self.transcription_options.model.hugging_face_model_id = model
+        self.transcription_options_changed.emit(self.transcription_options)
+

 class MenuBar(QMenuBar):
     import_action_triggered = pyqtSignal()
@@ -1080,7 +1232,6 @@ class AdvancedSettingsDialog(QDialog):

         self.transcription_options = transcription_options

-        self.setFixedSize(400, 180)
         self.setWindowTitle('Advanced Settings')

         button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
@@ -1091,27 +1242,28 @@ class AdvancedSettingsDialog(QDialog):

         default_temperature_text = ', '.join(
             [str(temp) for temp in transcription_options.temperature])
-        self.temperature_line_edit = QLineEdit(default_temperature_text, self)
+        self.temperature_line_edit = LineEdit(default_temperature_text, self)
         self.temperature_line_edit.setPlaceholderText(
             'Comma-separated, e.g. 
"0.0, 0.2, 0.4, 0.6, 0.8, 1.0"') + self.temperature_line_edit.setMinimumWidth(170) self.temperature_line_edit.textChanged.connect( self.on_temperature_changed) self.temperature_line_edit.setValidator(TemperatureValidator(self)) - self.temperature_line_edit.setDisabled( - transcription_options.model.is_whisper_cpp()) + self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER) self.initial_prompt_text_edit = QPlainTextEdit( transcription_options.initial_prompt, self) self.initial_prompt_text_edit.textChanged.connect( self.on_initial_prompt_changed) - self.initial_prompt_text_edit.setDisabled( - transcription_options.model.is_whisper_cpp()) + self.initial_prompt_text_edit.setEnabled( + transcription_options.model.model_type == ModelType.WHISPER) layout.addRow('Temperature:', self.temperature_line_edit) layout.addRow('Initial Prompt:', self.initial_prompt_text_edit) layout.addWidget(button_box) self.setLayout(layout) + self.setFixedSize(self.sizeHint()) def on_temperature_changed(self, text: str): try: diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 40ac120..9350052 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -1,7 +1,9 @@ +import enum import hashlib import logging import os import warnings +from dataclasses import dataclass from typing import Optional import requests @@ -9,9 +11,31 @@ import whisper from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot from platformdirs import user_cache_dir -from buzz.transcriber import Model +from buzz import transformers_whisper -MODELS_SHA256 = { + +class WhisperModelSize(enum.Enum): + TINY = 'tiny' + BASE = 'base' + SMALL = 'small' + MEDIUM = 'medium' + LARGE = 'large' + + +class ModelType(enum.Enum): + WHISPER = 'Whisper' + WHISPER_CPP = 'Whisper.cpp' + HUGGING_FACE = 'Hugging Face' + + +@dataclass() +class TranscriptionModel: + model_type: ModelType = ModelType.WHISPER + whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY + hugging_face_model_id: Optional[str] = None + + +WHISPER_CPP_MODELS_SHA256 = { 'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21', 'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe', 'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b', @@ -20,53 +44,85 @@ MODELS_SHA256 = { } +def get_hugging_face_dataset_file_url(author: str, repository_name: str, filename: str): + return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}' + + class ModelLoader(QObject): progress = pyqtSignal(tuple) # (current, total) finished = pyqtSignal(str) error = pyqtSignal(str) stopped = False - def __init__(self, model: Model, parent: Optional['QObject'] = None) -> None: + def __init__(self, model: TranscriptionModel, parent: Optional['QObject'] = None) -> None: super().__init__(parent) - self.name = model.model_name() - self.use_whisper_cpp = model.is_whisper_cpp() + self.model_type = model.model_type + self.whisper_model_size = model.whisper_model_size + self.hugging_face_model_id = model.hugging_face_model_id @pyqtSlot() def run(self): + if self.model_type == ModelType.WHISPER_CPP: + root_dir = user_cache_dir('Buzz') + model_name = self.whisper_model_size.value + url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp', + filename=f'ggml-{model_name}.bin') + file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin') + expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name] + self.download_model(url, file_path, 
expected_sha256) + return + + if self.model_type == ModelType.WHISPER: + root_dir = os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache", "whisper") + ) + model_name = self.whisper_model_size.value + url = whisper._MODELS[model_name] + file_path = os.path.join(root_dir, os.path.basename(url)) + expected_sha256 = url.split('/')[-2] + self.download_model(url, file_path, expected_sha256) + return + + if self.model_type == ModelType.HUGGING_FACE: + self.progress.emit((0, 100)) + + try: + # Loads the model from cache or download if not in cache + transformers_whisper.load_model(self.hugging_face_model_id) + except (FileNotFoundError, EnvironmentError) as exception: + self.error.emit(f'{exception}') + return + + self.progress.emit((100, 100)) + self.finished.emit(self.hugging_face_model_id) + return + + def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]): try: - if self.use_whisper_cpp: - root = user_cache_dir('Buzz') - url = f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{self.name}.bin' - model_path = os.path.join(root, f'ggml-model-whisper-{self.name}.bin') - else: - root = os.getenv( - "XDG_CACHE_HOME", - os.path.join(os.path.expanduser("~"), ".cache", "whisper") - ) - url = whisper._MODELS[self.name] - model_path = os.path.join(root, os.path.basename(url)) + os.makedirs(os.path.dirname(file_path), exist_ok=True) - os.makedirs(root, exist_ok=True) - - if os.path.exists(model_path) and not os.path.isfile(model_path): + if os.path.exists(file_path) and not os.path.isfile(file_path): raise RuntimeError( - f"{model_path} exists and is not a regular file") + f"{file_path} exists and is not a regular file") - expected_sha256 = MODELS_SHA256[self.name] if self.use_whisper_cpp else url.split( - "/")[-2] - if os.path.isfile(model_path): - model_bytes = open(model_path, "rb").read() + if os.path.isfile(file_path): + if expected_sha256 is None: + self.finished.emit(file_path) + return + + model_bytes = open(file_path, "rb").read() model_sha256 = hashlib.sha256(model_bytes).hexdigest() if model_sha256 == expected_sha256: - self.finished.emit(model_path) + self.finished.emit(file_path) return else: warnings.warn( - f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file") + f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file") # Downloads the model using the requests module instead of urllib to # use the certs from certifi when the app is running in frozen mode - with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output: + with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output: source.raise_for_status() total_size = float(source.headers.get('Content-Length', 0)) current = 0.0 @@ -78,12 +134,14 @@ class ModelLoader(QObject): current += len(chunk) self.progress.emit((current, total_size)) - model_bytes = open(model_path, "rb").read() - if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: - raise RuntimeError( - "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.") + if expected_sha256 is not None: + model_bytes = open(file_path, "rb").read() + if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not match. 
Please retry loading the " + "model.") - self.finished.emit(model_path) + self.finished.emit(file_path) except RuntimeError as exc: self.error.emit(str(exc)) logging.exception('') diff --git a/buzz/transcriber.py b/buzz/transcriber.py index caf2888..826a2cf 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -15,7 +15,6 @@ from dataclasses import dataclass, field from multiprocessing.connection import Connection from threading import Thread from typing import Any, Callable, List, Optional, Tuple, Union -import typing import ffmpeg import numpy as np @@ -25,7 +24,9 @@ import whisper from PyQt6.QtCore import QObject, QProcess, pyqtSignal, pyqtSlot, QThread from sounddevice import PortAudioError +from . import transformers_whisper from .conn import pipe_stderr +from .model_loader import TranscriptionModel, ModelType # Catch exception from whisper.dll not getting loaded. # TODO: Remove flag and try-except when issue with loading @@ -53,32 +54,11 @@ class Segment: text: str -class Model(enum.Enum): - WHISPER_TINY = 'Whisper - Tiny' - WHISPER_BASE = 'Whisper - Base' - WHISPER_SMALL = 'Whisper - Small' - WHISPER_MEDIUM = 'Whisper - Medium' - WHISPER_LARGE = 'Whisper - Large' - WHISPER_CPP_TINY = 'Whisper.cpp - Tiny' - WHISPER_CPP_BASE = 'Whisper.cpp - Base' - WHISPER_CPP_SMALL = 'Whisper.cpp - Small' - WHISPER_CPP_MEDIUM = 'Whisper.cpp - Medium' - WHISPER_CPP_LARGE = 'Whisper.cpp - Large' - - def is_whisper_cpp(self) -> bool: - model_type, _ = self.value.split(' - ') - return model_type == 'Whisper.cpp' - - def model_name(self) -> str: - _, model_name = self.value.split(' - ') - return model_name.lower() - - @dataclass() class TranscriptionOptions: language: Optional[str] = None task: Task = Task.TRANSCRIBE - model: Model = Model.WHISPER_TINY + model: TranscriptionModel = TranscriptionModel() word_level_timings: bool = False temperature: Tuple[float, ...] 
= DEFAULT_WHISPER_TEMPERATURE initial_prompt: str = '' @@ -373,7 +353,7 @@ class WhisperFileTranscriber(QObject): current_process: multiprocessing.Process progress = pyqtSignal(tuple) # (current, total) - completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment]) + completed = pyqtSignal(list) # List[Segment] error = pyqtSignal(str) running = False read_line_thread: Optional[Thread] = None @@ -390,6 +370,8 @@ class WhisperFileTranscriber(QObject): self.temperature = task.transcription_options.temperature self.initial_prompt = task.transcription_options.initial_prompt self.model_path = task.model_path + self.transcription_options = task.transcription_options + self.transcription_task = task self.segments = [] @pyqtSlot() @@ -406,13 +388,8 @@ class WhisperFileTranscriber(QObject): recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False) - self.current_process = multiprocessing.Process( - target=transcribe_whisper, - args=( - send_pipe, model_path, self.file_path, - self.language, self.task, self.word_level_timings, - self.temperature, self.initial_prompt - )) + self.current_process = multiprocessing.Process(target=transcribe_whisper, + args=(send_pipe, self.transcription_task)) self.current_process.start() self.read_line_thread = Thread( @@ -430,8 +407,9 @@ class WhisperFileTranscriber(QObject): self.read_line_thread.join() - if self.current_process.exitcode != 0: - self.completed.emit((self.current_process.exitcode, [])) + # TODO: fix error handling when process crashes + if self.current_process.exitcode != 0 and self.current_process.exitcode is not None: + self.completed.emit([]) self.running = False @@ -458,7 +436,7 @@ class WhisperFileTranscriber(QObject): ) for segment in segments_dict] self.current_process.join() # TODO: move this back to the parent thread - self.completed.emit((self.current_process.exitcode, segments)) + self.completed.emit(segments) else: try: progress = int(line.split('|')[0].strip().strip('%')) @@ -467,26 +445,30 @@ class WhisperFileTranscriber(QObject): continue -def transcribe_whisper( - stderr_conn: Connection, model_path: str, file_path: str, - language: Optional[str], task: Task, - word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str): +def transcribe_whisper(stderr_conn: Connection, task: FileTranscriptionTask): with pipe_stderr(stderr_conn): - model = whisper.load_model(model_path) - - if word_level_timings: - stable_whisper.modify_model(model) - result = model.transcribe( - audio=file_path, language=language, - task=task.value, temperature=temperature, - initial_prompt=initial_prompt, pbar=True) + if task.transcription_options.model.model_type == ModelType.HUGGING_FACE: + model = transformers_whisper.load_model(task.model_path) + language = task.transcription_options.language if task.transcription_options.language is not None else 'en' + result = model.transcribe(audio_path=task.file_path, language=language, + task=task.transcription_options.task.value, verbose=False) + whisper_segments = result.get('segments') else: - result = model.transcribe( - audio=file_path, language=language, task=task.value, temperature=temperature, - initial_prompt=initial_prompt, verbose=False) - - whisper_segments = stable_whisper.group_word_timestamps( - result) if word_level_timings else result.get('segments') + model = whisper.load_model(task.model_path) + if task.transcription_options.word_level_timings: + stable_whisper.modify_model(model) + result = model.transcribe( + audio=task.file_path, 
language=task.transcription_options.language, + task=task.transcription_options.task.value, temperature=task.transcription_options.temperature, + initial_prompt=task.transcription_options.initial_prompt, pbar=True) + whisper_segments = stable_whisper.group_word_timestamps(result) + else: + result = model.transcribe( + audio=task.file_path, language=task.transcription_options.language, + task=task.transcription_options.task.value, + temperature=task.transcription_options.temperature, + initial_prompt=task.transcription_options.initial_prompt, verbose=False) + whisper_segments = result.get('segments') segments = [ Segment( @@ -638,7 +620,7 @@ class FileTranscriberQueueWorker(QObject): self.completed.emit() return - if self.current_task.transcription_options.model.is_whisper_cpp(): + if self.current_task.transcription_options.model.model_type == ModelType.WHISPER_CPP: self.current_transcriber = WhisperCppFileTranscriber( task=self.current_task) else: @@ -688,10 +670,9 @@ class FileTranscriberQueueWorker(QObject): self.current_task.fraction_completed = progress[0] / progress[1] self.task_updated.emit(self.current_task) - @pyqtSlot(tuple) - def on_task_completed(self, result: Tuple[int, List[Segment]]): + @pyqtSlot(list) + def on_task_completed(self, segments: List[Segment]): if self.current_task is not None: - _, segments = result self.current_task.status = FileTranscriptionTask.Status.COMPLETED self.current_task.segments = segments self.task_updated.emit(self.current_task) diff --git a/buzz/transformers_whisper.py b/buzz/transformers_whisper.py new file mode 100644 index 0000000..672e88c --- /dev/null +++ b/buzz/transformers_whisper.py @@ -0,0 +1,56 @@ +from typing import Optional + +import numpy as np +import whisper +from tqdm import tqdm +from transformers import WhisperProcessor, WhisperForConditionalGeneration + + +def load_model(model_name_or_path: str): + processor = WhisperProcessor.from_pretrained(model_name_or_path) + model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path) + return TransformersWhisper(processor, model) + + +class TransformersWhisper: + SAMPLE_RATE = whisper.audio.SAMPLE_RATE + N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES + + def __init__(self, processor: WhisperProcessor, model: WhisperForConditionalGeneration): + self.processor = processor + self.model = model + + # Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and + # timestamps are available. See: https://github.com/huggingface/transformers/issues/19887, + # https://github.com/huggingface/transformers/pull/20620. 
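+    # The audio is decoded in fixed 30-second windows (whisper.audio.N_SAMPLES samples
+    # per chunk), each generated independently, so segment boundaries fall on window
+    # edges rather than on detected phrase boundaries.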
+ def transcribe(self, audio_path: str, language: str, task: str, verbose: Optional[bool] = None): + audio: np.ndarray = whisper.load_audio(audio_path, sr=self.SAMPLE_RATE) + self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(task=task, language=language) + + segments = [] + all_predicted_ids = [] + + num_samples = audio.size + seek = 0 + with tqdm(total=num_samples, unit='samples', disable=verbose is not False) as progress_bar: + while seek < num_samples: + chunk = audio[seek: seek + self.N_SAMPLES_IN_CHUNK] + input_features = self.processor(chunk, return_tensors="pt", + sampling_rate=self.SAMPLE_RATE).input_features + predicted_ids = self.model.generate(input_features) + all_predicted_ids.extend(predicted_ids) + text: str = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + if text.strip() != '': + segments.append({ + 'start': seek / self.SAMPLE_RATE, + 'end': min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) / self.SAMPLE_RATE, + 'text': text + }) + + progress_bar.update(min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek) + seek += self.N_SAMPLES_IN_CHUNK + + return { + 'text': self.processor.batch_decode(all_predicted_ids, skip_special_tokens=True)[0], + 'segments': segments + } diff --git a/tests/gui_test.py b/tests/gui_test.py index 886985d..4872c8c 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -6,19 +6,19 @@ from unittest.mock import Mock, patch import pytest import sounddevice from PyQt6.QtCore import QSize, Qt -from PyQt6.QtGui import QValidator -from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget +from PyQt6.QtGui import QValidator, QKeyEvent +from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication from pytestqt.qtbot import QtBot from buzz.cache import TasksCache from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application, AudioDevicesComboBox, DownloadModelProgressDialog, FileTranscriberWidget, LanguagesComboBox, MainWindow, - ModelComboBox, RecordingTranscriberWidget, + RecordingTranscriberWidget, TemperatureValidator, TextDisplayBox, - TranscriptionTasksTableWidget, TranscriptionViewerWidget) + TranscriptionTasksTableWidget, TranscriptionViewerWidget, HuggingFaceSearchLineEdit) from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, - Model, Segment, TranscriptionOptions) + Segment, TranscriptionOptions) class TestApplication: @@ -48,20 +48,6 @@ class TestLanguagesComboBox: assert languages_combo_box.currentText() == 'Detect Language' -class TestModelComboBox: - model_combo_box = ModelComboBox( - default_model=Model.WHISPER_CPP_BASE, parent=None) - - def test_should_show_qualities(self): - assert self.model_combo_box.itemText(0) == 'Whisper - Tiny' - assert self.model_combo_box.itemText(1) == 'Whisper - Base' - assert self.model_combo_box.itemText(2) == 'Whisper - Small' - assert self.model_combo_box.itemText(3) == 'Whisper - Medium' - - def test_should_select_default_model(self): - assert self.model_combo_box.currentText() == 'Whisper.cpp - Base' - - class TestAudioDevicesComboBox: mock_query_devices = [ {'name': 'Background Music', 'index': 0, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2, @@ -198,10 +184,9 @@ class TestFileTranscriberWidget: widget = FileTranscriberWidget( file_paths=['testdata/whisper-french.mp3'], parent=None) - def test_should_set_window_title_and_size(self, qtbot: QtBot): + def test_should_set_window_title(self, qtbot: QtBot): qtbot.addWidget(self.widget) assert self.widget.windowTitle() == 
'whisper-french.mp3' - assert self.widget.size() == QSize(420, 270) def test_should_emit_triggered_event(self, qtbot: QtBot): widget = FileTranscriberWidget( @@ -217,7 +202,6 @@ class TestFileTranscriberWidget: transcription_options, file_transcription_options, model_path = mock_triggered.call_args[ 0][0] assert transcription_options.language is None - assert transcription_options.model == Model.WHISPER_TINY assert file_transcription_options.file_paths == [ 'testdata/whisper-french.mp3'] assert len(model_path) > 0 @@ -232,8 +216,7 @@ class TestAboutDialog: class TestAdvancedSettingsDialog: def test_should_update_advanced_settings(self, qtbot: QtBot): dialog = AdvancedSettingsDialog( - transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt', - model=Model.WHISPER_CPP_BASE)) + transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt')) qtbot.add_widget(dialog) transcription_options_mock = Mock() @@ -333,7 +316,50 @@ class TestTranscriptionTasksTableWidget: class TestRecordingTranscriberWidget: widget = RecordingTranscriberWidget() - def test_should_set_window_title_and_size(self, qtbot: QtBot): + def test_should_set_window_title(self, qtbot: QtBot): qtbot.add_widget(self.widget) assert self.widget.windowTitle() == 'Live Recording' - assert self.widget.size() == QSize(400, 520) + + +class TestHuggingFaceSearchLineEdit: + def test_should_update_selected_model_on_type(self, qtbot: QtBot): + widget = HuggingFaceSearchLineEdit() + qtbot.add_widget(widget) + + mock_model_selected = Mock() + widget.model_selected.connect(mock_model_selected) + + self._set_text_and_wait_response(qtbot, widget) + mock_model_selected.assert_called_with('openai/whisper-tiny') + + def test_should_show_list_of_models(self, qtbot: QtBot): + widget = HuggingFaceSearchLineEdit() + qtbot.add_widget(widget) + + self._set_text_and_wait_response(qtbot, widget) + + assert widget.popup.count() > 0 + assert 'openai/whisper-tiny' in widget.popup.item(0).text() + + def test_should_select_model_from_list(self, qtbot: QtBot): + widget = HuggingFaceSearchLineEdit() + qtbot.add_widget(widget) + + mock_model_selected = Mock() + widget.model_selected.connect(mock_model_selected) + + self._set_text_and_wait_response(qtbot, widget) + + # press down arrow and enter to select next item + QApplication.sendEvent(widget.popup, + QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier)) + QApplication.sendEvent(widget.popup, + QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Enter, Qt.KeyboardModifier.NoModifier)) + + mock_model_selected.assert_called_with('openai/whisper-tiny.en') + + @staticmethod + def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit): + with qtbot.wait_signal(widget.network_manager.finished, timeout=30 * 1000): + widget.setText('openai/whisper-tiny') + widget.textEdited.emit('openai/whisper-tiny') diff --git a/tests/model_loader.py b/tests/model_loader.py new file mode 100644 index 0000000..1dc27c0 --- /dev/null +++ b/tests/model_loader.py @@ -0,0 +1,14 @@ +from buzz.model_loader import ModelLoader, TranscriptionModel + + +def get_model_path(transcription_model: TranscriptionModel) -> str: + model_loader = ModelLoader(model=transcription_model) + model_path = '' + + def on_load_model(path: str): + nonlocal model_path + model_path = path + + model_loader.finished.connect(on_load_model) + model_loader.run() + return model_path diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index 
28e62c6..df2f204 100644
--- a/tests/transcriber_test.py
+++ b/tests/transcriber_test.py
@@ -8,30 +8,18 @@ from unittest.mock import Mock
 import pytest
 from pytestqt.qtbot import QtBot

-from buzz.model_loader import ModelLoader
-from buzz.transcriber import (FileTranscriberQueueWorker, FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber, Segment, Task,
-                              WhisperCpp, WhisperCppFileTranscriber,
+from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
+from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
+                              Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
                               WhisperFileTranscriber,
                               get_default_output_file_path, to_timestamp,
-                              whisper_cpp_params, write_output, TranscriptionOptions, Model)
-
-
-def get_model_path(model: Model) -> str:
-    model_loader = ModelLoader(model=model)
-    model_path = ''
-
-    def on_load_model(path: str):
-        nonlocal model_path
-        model_path = path
-
-    model_loader.finished.connect(on_load_model)
-    model_loader.run()
-    return model_path
+                              whisper_cpp_params, write_output, TranscriptionOptions)
+from tests.model_loader import get_model_path


 class TestRecordingTranscriber:
     def test_transcriber(self):
-        model_path = get_model_path(Model.WHISPER_CPP_TINY)
+        model_path = get_model_path(transcription_model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
+                                                                           whisper_model_size=WhisperModelSize.TINY))
         transcriber = RecordingTranscriber(
             model_path=model_path, use_whisper_cpp=True, language='en',
             task=Task.TRANSCRIBE)
@@ -49,11 +37,15 @@ class TestWhisperCppFileTranscriber:
         file_transcription_options = FileTranscriptionOptions(
             file_paths=['testdata/whisper-french.mp3'])
         transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
-                                                     word_level_timings=word_level_timings)
+                                                     word_level_timings=word_level_timings,
+                                                     model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
+                                                                              whisper_model_size=WhisperModelSize.TINY))

-        model_path = get_model_path(Model.WHISPER_CPP_TINY)
+        model_path = get_model_path(transcription_options.model)
         transcriber = WhisperCppFileTranscriber(
-            task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3', transcription_options=transcription_options, file_transcription_options=file_transcription_options, model_path=model_path))
+            task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3',
+                                       transcription_options=transcription_options,
+                                       file_transcription_options=file_transcription_options, model_path=model_path))
         mock_progress = Mock()
         mock_completed = Mock()
         transcriber.progress.connect(mock_progress)
@@ -81,44 +73,51 @@ class TestWhisperFileTranscriber:
         assert srt.endswith('.srt')

     @pytest.mark.parametrize(
-        'word_level_timings,expected_segments',
+        'word_level_timings,expected_segments,model,check_progress',
         [
-            (False, [
-                Segment(
-                    0, 6560,
-                    ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'),
-            ]),
-            (True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')])
+            (False, [Segment(0, 6560,
+                             ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')],
+             TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
+            (True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')],
+             TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
+            (False, [Segment(0, 8517,
+                             ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances '
+                             'et des apprenances de français.')],
+             TranscriptionModel(model_type=ModelType.HUGGING_FACE,
+                                hugging_face_model_id='openai/whisper-tiny'), False)
         ])
-    def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
-        model_path = get_model_path(Model.WHISPER_TINY)
-
+    def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment],
+                        model: TranscriptionModel, check_progress):
         mock_progress = Mock()
         mock_completed = Mock()
         transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
-                                                     word_level_timings=word_level_timings)
+                                                     word_level_timings=word_level_timings,
+                                                     model=model)
+        model_path = get_model_path(transcription_options.model)
         file_transcription_options = FileTranscriptionOptions(
             file_paths=['testdata/whisper-french.mp3'])

         transcriber = WhisperFileTranscriber(
-            task=FileTranscriptionTask(transcription_options=transcription_options, file_transcription_options=file_transcription_options, file_path='testdata/whisper-french.mp3', model_path=model_path))
+            task=FileTranscriptionTask(transcription_options=transcription_options,
+                                       file_transcription_options=file_transcription_options,
+                                       file_path='testdata/whisper-french.mp3', model_path=model_path))
         transcriber.progress.connect(mock_progress)
         transcriber.completed.connect(mock_completed)
         with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
             transcriber.run()

-        # Reports progress at 0, 0
@@ ... @@
-        (OutputFormat.SRT,
-         '1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
+        (
+            OutputFormat.SRT,
+            '1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
         (OutputFormat.VTT,
          'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
     ])
diff --git a/tests/transformers_whisper_test.py b/tests/transformers_whisper_test.py
new file mode 100644
index 0000000..1d93307
--- /dev/null
+++ b/tests/transformers_whisper_test.py
@@ -0,0 +1,10 @@
+from buzz.transformers_whisper import load_model
+
+
+class TestTransformersWhisper:
+    def test_should_transcribe(self):
+        model = load_model('openai/whisper-tiny')
+        result = model.transcribe(
+            audio_path='testdata/whisper-french.mp3', language='fr', task='transcribe')
+
+        assert 'Bienvenue dans Passe' in result['text']
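Note on the new model-selection API: the `Model` enum (e.g. `Model.WHISPER_CPP_TINY`) is replaced by the `TranscriptionModel` dataclass added in buzz/model_loader.py, which separates the implementation (`ModelType`) from the checkpoint size or Hugging Face repository ID. A minimal sketch of how the three variants are constructed and resolved to a local path; the `print` slot is a stand-in for a real handler, and `run()` downloads the model if it is not already cached:

```python
from buzz.model_loader import ModelLoader, ModelType, TranscriptionModel, WhisperModelSize

# Local Whisper checkpoint, tiny size (the dataclass defaults).
whisper_tiny = TranscriptionModel()

# Whisper.cpp GGML checkpoint, as used by the live-recording widget.
whisper_cpp_tiny = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
                                      whisper_model_size=WhisperModelSize.TINY)

# Hugging Face model, selected by repository ID; whisper_model_size is unused here.
hugging_face_tiny = TranscriptionModel(model_type=ModelType.HUGGING_FACE,
                                       hugging_face_model_id='openai/whisper-tiny')

# ModelLoader resolves a TranscriptionModel to a local file path (or, for
# Hugging Face models, the model ID) and emits `finished` with the result.
loader = ModelLoader(model=whisper_cpp_tiny)
loader.error.connect(print)
loader.finished.connect(print)
loader.run()
```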