Add support for Hugging Face models (#264)

Chidi Williams 2022-12-26 12:48:45 +00:00 committed by GitHub
parent 82bdd30fb8
commit 3dceb11c4d
11 changed files with 515 additions and 214 deletions
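
For context: the core of this change replaces the flat Model enum with a composable TranscriptionModel dataclass (see buzz/model_loader.py below). A minimal sketch of the new shape, using only names introduced in this diff:

    from buzz.model_loader import ModelType, TranscriptionModel, WhisperModelSize

    # Equivalent of the old Model.WHISPER_CPP_TINY enum member
    local_model = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
                                     whisper_model_size=WhisperModelSize.TINY)

    # New in this commit: any Whisper checkpoint hosted on Hugging Face
    hf_model = TranscriptionModel(model_type=ModelType.HUGGING_FACE,
                                  hugging_face_model_id='openai/whisper-tiny')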

View file

@ -7,4 +7,4 @@ omit =
directory = coverage/html
[report]
fail_under = 75
fail_under = 78

View file

@ -43,7 +43,7 @@ clean:
rm -rf dist/* || true
test: buzz/whisper_cpp.py
pytest --cov=buzz --cov-report=xml --cov-report=html
pytest -vv --cov=buzz --cov-report=xml --cov-report=html
dist/Buzz dist/Buzz.app: buzz/whisper_cpp.py
pyinstaller --noconfirm Buzz.spec

View file

@ -23,7 +23,7 @@ class TasksCache:
return pickle.load(file)
except FileNotFoundError:
return []
except pickle.UnpicklingError: # delete corrupted cache
except (pickle.UnpicklingError, AttributeError): # delete corrupted cache
os.remove(self.file_path)
return []
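
Catching AttributeError here matters because pickle raises it when a cached object references a class that no longer exists; this commit removes the old Model enum, so task caches written by earlier versions hit exactly that. A minimal reproduction sketch (the class name is illustrative):

    import pickle

    class OldModel:  # stands in for the removed buzz.transcriber.Model enum
        pass

    data = pickle.dumps(OldModel())
    del OldModel  # simulate the class disappearing in a later release

    try:
        pickle.loads(data)
    except AttributeError as exc:
        print('corrupted cache:', exc)  # "Can't get attribute 'OldModel' ..."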

View file

@ -1,4 +1,7 @@
import dataclasses
import enum
import json
import logging
import os
import platform
import random
@ -10,27 +13,28 @@ import humanize
import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QSize)
QTimer, QUrl, pyqtSignal, QModelIndex, QSize, QPoint,
QUrlQuery, QMetaObject, QEvent)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor, QValidator)
QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QDialogButtonBox, QFileDialog, QLabel, QLineEdit,
QMainWindow, QMessageBox, QPlainTextEdit,
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem,
QHeaderView, QAbstractItemView)
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem)
from requests import get
from whisper import tokenizer
from buzz.cache import TasksCache
from .__version__ import VERSION
from .model_loader import ModelLoader
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
RecordingTranscriber, Task,
WhisperCppFileTranscriber, WhisperFileTranscriber,
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
Model, FileTranscriberQueueWorker, FileTranscriptionTask)
FileTranscriberQueueWorker, FileTranscriptionTask)
APP_NAME = 'Buzz'
@ -99,7 +103,7 @@ class AudioDevicesComboBox(QComboBox):
class LanguagesComboBox(QComboBox):
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
# language is a languge key from whisper.tokenizer.LANGUAGES or '' for "detect language"
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
languageChanged = pyqtSignal(str)
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
@ -136,20 +140,6 @@ class TasksComboBox(QComboBox):
self.taskChanged.emit(self.tasks[index])
class ModelComboBox(QComboBox):
model_changed = pyqtSignal(Model)
def __init__(self, default_model: Model, parent: Optional[QWidget], *args) -> None:
super().__init__(parent, *args)
self.models = [model for model in Model]
self.addItems([model.value for model in self.models])
self.currentIndexChanged.connect(self.on_index_changed)
self.setCurrentText(default_model.value)
def on_index_changed(self, index: int):
self.model_changed.emit(self.models[index])
class TextDisplayBox(QPlainTextEdit):
"""TextDisplayBox is a read-only textbox"""
@ -202,6 +192,7 @@ class DownloadModelProgressDialog(QProgressDialog):
'Cancel', 0, 100, parent, *args)
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.start_time = datetime.now()
self.setFixedSize(self.size())
def set_fraction_completed(self, fraction_completed: float) -> None:
self.setValue(int(fraction_completed * self.maximum()))
@ -280,7 +271,7 @@ class TimerLabel(QLabel):
def show_model_download_error_dialog(parent: QWidget, error: str):
message = f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more ' \
message = f"An error occurred while loading the Whisper model: {error}{'' if error.endswith('.') else '.'}" \
f'information. '
QMessageBox.critical(parent, '', message)
@ -302,7 +293,6 @@ class FileTranscriberWidget(QWidget):
super().__init__(parent, flags)
self.setWindowTitle(file_paths_as_title(file_paths))
self.setFixedSize(420, 270)
self.file_paths = file_paths
self.transcription_options = TranscriptionOptions()
@ -332,15 +322,19 @@ class FileTranscriberWidget(QWidget):
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
self.transcription_options.model.model_type == ModelType.HUGGING_FACE)
def on_click_run(self):
self.run_button.setDisabled(True)
self.transcriber_thread = QThread()
self.model_loader = ModelLoader(model=self.transcription_options.model)
self.model_loader.moveToThread(self.transcriber_thread)
self.transcriber_thread.started.connect(self.model_loader.run)
self.model_loader.finished.connect(
@ -378,6 +372,7 @@ class FileTranscriberWidget(QWidget):
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
def on_download_model_error(self, error: str):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.reset_transcriber_controls()
@ -391,6 +386,7 @@ class FileTranscriberWidget(QWidget):
def reset_model_download(self):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
def on_word_level_timings_changed(self, value: int):
@ -477,9 +473,9 @@ class RecordingTranscriberWidget(QWidget):
layout = QVBoxLayout(self)
self.setWindowTitle('Live Recording')
self.setFixedSize(400, 520)
self.transcription_options = TranscriptionOptions(model=Model.WHISPER_CPP_TINY)
self.transcription_options = TranscriptionOptions(model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY))
self.audio_devices_combo_box = AudioDevicesComboBox(self)
self.audio_devices_combo_box.device_changed.connect(
@ -514,6 +510,7 @@ class RecordingTranscriberWidget(QWidget):
layout.addWidget(self.text_box)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def closeEvent(self, event: QCloseEvent) -> None:
self.stop_recording()
@ -534,8 +531,8 @@ class RecordingTranscriberWidget(QWidget):
def start_recording(self):
self.record_button.setDisabled(True)
use_whisper_cpp = self.transcription_options.model.is_whisper_cpp(
) and self.transcription_options.language is not None
use_whisper_cpp = self.transcription_options.model.model_type == ModelType.WHISPER_CPP and \
self.transcription_options.language is not None
def start_recording_transcription(model_path: str):
# Clear text box placeholder because the first chunk takes a while to process
@ -593,6 +590,7 @@ class RecordingTranscriberWidget(QWidget):
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
def on_download_model_error(self, error: str):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.stop_recording()
self.record_button.force_stop()
@ -620,6 +618,7 @@ class RecordingTranscriberWidget(QWidget):
def reset_model_download(self):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
@ -964,6 +963,131 @@ class MainWindow(QMainWindow):
super().closeEvent(event)
class LineEdit(QLineEdit):
def __init__(self, default_text: str = '', parent: Optional[QWidget] = None):
super().__init__(default_text, parent)
if platform.system() == 'Darwin':
self.setStyleSheet('QLineEdit { padding: 4px }')
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
class HuggingFaceSearchLineEdit(LineEdit):
model_selected = pyqtSignal(str)
popup: QListWidget
def __init__(self, parent: Optional[QWidget] = None):
super().__init__('', parent)
self.setMinimumWidth(150)
self.setPlaceholderText('openai/whisper-tiny')
self.timer = QTimer(self)
self.timer.setSingleShot(True)
self.timer.setInterval(250)
self.timer.timeout.connect(self.fetch_models)
# Restart debounce timer each time editor text changes
self.textEdited.connect(self.timer.start)
self.textEdited.connect(self.on_text_edited)
self.network_manager = QNetworkAccessManager(self)
self.network_manager.finished.connect(self.on_request_response)
self.popup = QListWidget()
self.popup.setWindowFlags(Qt.WindowType.Popup)
self.popup.setFocusProxy(self)
self.popup.setMouseTracking(True)
self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.popup.installEventFilter(self)
self.popup.itemClicked.connect(self.on_select_item)
def on_text_edited(self, text: str):
self.model_selected.emit(text)
def on_select_item(self):
self.popup.hide()
self.setFocus()
item = self.popup.currentItem()
self.setText(item.text())
QMetaObject.invokeMethod(self, 'returnPressed')
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
def fetch_models(self):
text = self.text()
if len(text) < 3:
return
url = QUrl("https://huggingface.co/api/models")
query = QUrlQuery()
query.addQueryItem("filter", "whisper")
query.addQueryItem("search", text)
url.setQuery(query)
return self.network_manager.get(QNetworkRequest(url))
def on_popup_selected(self):
self.timer.stop()
def on_request_response(self, network_reply: QNetworkReply):
if network_reply.error() != QNetworkReply.NetworkError.NoError:
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
return
models = json.loads(network_reply.readAll().data())
self.popup.setUpdatesEnabled(False)
self.popup.clear()
for model in models:
model_id = model.get('id')
item = QListWidgetItem(self.popup)
item.setText(model_id)
item.setData(Qt.ItemDataRole.UserRole, model_id)
self.popup.setCurrentItem(self.popup.item(0))
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
self.popup.setUpdatesEnabled(True)
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
self.popup.setFocus()
self.popup.show()
def eventFilter(self, target: QObject, event: QEvent):
if hasattr(self, 'popup') is False or target != self.popup:
return False
if event.type() == QEvent.Type.MouseButtonPress:
self.popup.hide()
self.setFocus()
return True
if isinstance(event, QKeyEvent):
key = event.key()
if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]:
self.on_select_item()
return True
if key == Qt.Key.Key_Escape:
self.setFocus()
self.popup.hide()
return True
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
Qt.Key.Key_PageDown]:
return False
self.setFocus()
self.event(event)
self.popup.hide()
return False
class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
@ -972,7 +1096,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
super().__init__(title='', parent=parent)
self.transcription_options = default_transcription_options
layout = QFormLayout(self)
self.form_layout = QFormLayout(self)
self.tasks_combo_box = TasksComboBox(
default_task=self.transcription_options.task,
@ -985,30 +1109,41 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.languages_combo_box.languageChanged.connect(
self.on_language_changed)
self.model_combo_box = ModelComboBox(
default_model=self.transcription_options.model,
parent=self)
self.model_combo_box.model_changed.connect(self.on_model_changed)
self.advanced_settings_button = AdvancedSettingsButton(self)
self.advanced_settings_button.clicked.connect(
self.open_advanced_settings)
layout.addRow('Task:', self.tasks_combo_box)
layout.addRow('Language:', self.languages_combo_box)
layout.addRow('Model:', self.model_combo_box)
layout.addRow('', self.advanced_settings_button)
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
self.setLayout(layout)
model_type_combo_box = QComboBox(self)
model_type_combo_box.addItems([model_type.value for model_type in ModelType])
model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)
self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
if default_transcription_options.model.whisper_model_size is not None:
self.whisper_model_size_combo_box.setCurrentText(
default_transcription_options.model.whisper_model_size.value.title())
self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed)
self.form_layout.addRow('Task:', self.tasks_combo_box)
self.form_layout.addRow('Language:', self.languages_combo_box)
self.form_layout.addRow('Model:', model_type_combo_box)
self.form_layout.addRow('', self.whisper_model_size_combo_box)
self.form_layout.addRow('', self.hugging_face_search_line_edit)
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, False)
self.form_layout.addRow('', self.advanced_settings_button)
self.setLayout(self.form_layout)
def on_language_changed(self, language: str):
self.transcription_options.language = language
self.transcription_options_changed.emit(self.transcription_options)
def on_model_changed(self, model: Model):
self.transcription_options.model = model
self.transcription_options_changed.emit(self.transcription_options)
def on_task_changed(self, task: Task):
self.transcription_options.task = task
self.transcription_options_changed.emit(self.transcription_options)
@ -1032,6 +1167,23 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.transcription_options = transcription_options
self.transcription_options_changed.emit(transcription_options)
def on_model_type_changed(self, text: str):
model_type = ModelType(text)
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP))
self.transcription_options.model.model_type = model_type
self.transcription_options_changed.emit(self.transcription_options)
def on_whisper_model_size_changed(self, text: str):
model_size = WhisperModelSize(text.lower())
self.transcription_options.model.whisper_model_size = model_size
self.transcription_options_changed.emit(self.transcription_options)
def on_hugging_face_model_changed(self, model: str):
self.transcription_options.model.hugging_face_model_id = model
self.transcription_options_changed.emit(self.transcription_options)
class MenuBar(QMenuBar):
import_action_triggered = pyqtSignal()
@ -1080,7 +1232,6 @@ class AdvancedSettingsDialog(QDialog):
self.transcription_options = transcription_options
self.setFixedSize(400, 180)
self.setWindowTitle('Advanced Settings')
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
@ -1091,27 +1242,28 @@ class AdvancedSettingsDialog(QDialog):
default_temperature_text = ', '.join(
[str(temp) for temp in transcription_options.temperature])
self.temperature_line_edit = QLineEdit(default_temperature_text, self)
self.temperature_line_edit = LineEdit(default_temperature_text, self)
self.temperature_line_edit.setPlaceholderText(
'Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')
self.temperature_line_edit.setMinimumWidth(170)
self.temperature_line_edit.textChanged.connect(
self.on_temperature_changed)
self.temperature_line_edit.setValidator(TemperatureValidator(self))
self.temperature_line_edit.setDisabled(
transcription_options.model.is_whisper_cpp())
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
self.initial_prompt_text_edit = QPlainTextEdit(
transcription_options.initial_prompt, self)
self.initial_prompt_text_edit.textChanged.connect(
self.on_initial_prompt_changed)
self.initial_prompt_text_edit.setDisabled(
transcription_options.model.is_whisper_cpp())
self.initial_prompt_text_edit.setEnabled(
transcription_options.model.model_type == ModelType.WHISPER)
layout.addRow('Temperature:', self.temperature_line_edit)
layout.addRow('Initial Prompt:', self.initial_prompt_text_edit)
layout.addWidget(button_box)
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def on_temperature_changed(self, text: str):
try:
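
The HuggingFaceSearchLineEdit above debounces the model search with a single-shot QTimer: each keystroke restarts a 250 ms countdown, so the Hugging Face API is only queried once typing pauses. A standalone sketch of just that debounce pattern (the print handler is illustrative):

    import sys
    from PyQt6.QtCore import QTimer
    from PyQt6.QtWidgets import QApplication, QLineEdit

    app = QApplication(sys.argv)
    line_edit = QLineEdit()

    timer = QTimer(line_edit)
    timer.setSingleShot(True)  # fire once per quiet period, not repeatedly
    timer.setInterval(250)     # milliseconds of inactivity before firing
    timer.timeout.connect(lambda: print('search:', line_edit.text()))

    # Calling start() on a running single-shot timer restarts its countdown,
    # so rapid keystrokes keep pushing the timeout into the future.
    line_edit.textEdited.connect(timer.start)

    line_edit.show()
    sys.exit(app.exec())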

View file

@ -1,7 +1,9 @@
import enum
import hashlib
import logging
import os
import warnings
from dataclasses import dataclass
from typing import Optional
import requests
@ -9,9 +11,31 @@ import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from platformdirs import user_cache_dir
from buzz.transcriber import Model
from buzz import transformers_whisper
MODELS_SHA256 = {
class WhisperModelSize(enum.Enum):
TINY = 'tiny'
BASE = 'base'
SMALL = 'small'
MEDIUM = 'medium'
LARGE = 'large'
class ModelType(enum.Enum):
WHISPER = 'Whisper'
WHISPER_CPP = 'Whisper.cpp'
HUGGING_FACE = 'Hugging Face'
@dataclass()
class TranscriptionModel:
model_type: ModelType = ModelType.WHISPER
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
hugging_face_model_id: Optional[str] = None
WHISPER_CPP_MODELS_SHA256 = {
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
@ -20,53 +44,85 @@ MODELS_SHA256 = {
}
def get_hugging_face_dataset_file_url(author: str, repository_name: str, filename: str):
return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}'
class ModelLoader(QObject):
progress = pyqtSignal(tuple) # (current, total)
finished = pyqtSignal(str)
error = pyqtSignal(str)
stopped = False
def __init__(self, model: Model, parent: Optional['QObject'] = None) -> None:
def __init__(self, model: TranscriptionModel, parent: Optional['QObject'] = None) -> None:
super().__init__(parent)
self.name = model.model_name()
self.use_whisper_cpp = model.is_whisper_cpp()
self.model_type = model.model_type
self.whisper_model_size = model.whisper_model_size
self.hugging_face_model_id = model.hugging_face_model_id
@pyqtSlot()
def run(self):
if self.model_type == ModelType.WHISPER_CPP:
root_dir = user_cache_dir('Buzz')
model_name = self.whisper_model_size.value
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin')
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
self.download_model(url, file_path, expected_sha256)
return
if self.model_type == ModelType.WHISPER:
root_dir = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
model_name = self.whisper_model_size.value
url = whisper._MODELS[model_name]
file_path = os.path.join(root_dir, os.path.basename(url))
expected_sha256 = url.split('/')[-2]
self.download_model(url, file_path, expected_sha256)
return
if self.model_type == ModelType.HUGGING_FACE:
self.progress.emit((0, 100))
try:
# Loads the model from cache or downloads it if not cached
transformers_whisper.load_model(self.hugging_face_model_id)
except (FileNotFoundError, EnvironmentError) as exception:
self.error.emit(f'{exception}')
return
self.progress.emit((100, 100))
self.finished.emit(self.hugging_face_model_id)
return
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
try:
if self.use_whisper_cpp:
root = user_cache_dir('Buzz')
url = f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{self.name}.bin'
model_path = os.path.join(root, f'ggml-model-whisper-{self.name}.bin')
else:
root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
url = whisper._MODELS[self.name]
model_path = os.path.join(root, os.path.basename(url))
os.makedirs(os.path.dirname(file_path), exist_ok=True)
os.makedirs(root, exist_ok=True)
if os.path.exists(model_path) and not os.path.isfile(model_path):
if os.path.exists(file_path) and not os.path.isfile(file_path):
raise RuntimeError(
f"{model_path} exists and is not a regular file")
f"{file_path} exists and is not a regular file")
expected_sha256 = MODELS_SHA256[self.name] if self.use_whisper_cpp else url.split(
"/")[-2]
if os.path.isfile(model_path):
model_bytes = open(model_path, "rb").read()
if os.path.isfile(file_path):
if expected_sha256 is None:
self.finished.emit(file_path)
return
model_bytes = open(file_path, "rb").read()
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
if model_sha256 == expected_sha256:
self.finished.emit(model_path)
self.finished.emit(file_path)
return
else:
warnings.warn(
f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output:
source.raise_for_status()
total_size = float(source.headers.get('Content-Length', 0))
current = 0.0
@ -78,12 +134,14 @@ class ModelLoader(QObject):
current += len(chunk)
self.progress.emit((current, total_size))
model_bytes = open(model_path, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
if expected_sha256 is not None:
model_bytes = open(file_path, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
"model.")
self.finished.emit(model_path)
self.finished.emit(file_path)
except RuntimeError as exc:
self.error.emit(str(exc))
logging.exception('')
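
The refactored download_model above streams the file with requests and checks a SHA256 digest afterwards; for OpenAI-hosted Whisper weights the expected digest is the second-to-last URL path segment, while Whisper.cpp digests come from the WHISPER_CPP_MODELS_SHA256 table. A condensed sketch of the verify-after-download flow, without the Qt progress signal:

    import hashlib
    import requests

    def download_verified(url: str, file_path: str, expected_sha256: str) -> str:
        # Stream to disk in chunks rather than holding the model in memory
        with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output:
            source.raise_for_status()
            for chunk in source.iter_content(chunk_size=8192):
                output.write(chunk)
        digest = hashlib.sha256(open(file_path, 'rb').read()).hexdigest()
        if digest != expected_sha256:
            raise RuntimeError('Model downloaded but the SHA256 checksum does not match; please retry')
        return file_path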

View file

@ -15,7 +15,6 @@ from dataclasses import dataclass, field
from multiprocessing.connection import Connection
from threading import Thread
from typing import Any, Callable, List, Optional, Tuple, Union
import typing
import ffmpeg
import numpy as np
@ -25,7 +24,9 @@ import whisper
from PyQt6.QtCore import QObject, QProcess, pyqtSignal, pyqtSlot, QThread
from sounddevice import PortAudioError
from . import transformers_whisper
from .conn import pipe_stderr
from .model_loader import TranscriptionModel, ModelType
# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
@ -53,32 +54,11 @@ class Segment:
text: str
class Model(enum.Enum):
WHISPER_TINY = 'Whisper - Tiny'
WHISPER_BASE = 'Whisper - Base'
WHISPER_SMALL = 'Whisper - Small'
WHISPER_MEDIUM = 'Whisper - Medium'
WHISPER_LARGE = 'Whisper - Large'
WHISPER_CPP_TINY = 'Whisper.cpp - Tiny'
WHISPER_CPP_BASE = 'Whisper.cpp - Base'
WHISPER_CPP_SMALL = 'Whisper.cpp - Small'
WHISPER_CPP_MEDIUM = 'Whisper.cpp - Medium'
WHISPER_CPP_LARGE = 'Whisper.cpp - Large'
def is_whisper_cpp(self) -> bool:
model_type, _ = self.value.split(' - ')
return model_type == 'Whisper.cpp'
def model_name(self) -> str:
_, model_name = self.value.split(' - ')
return model_name.lower()
@dataclass()
class TranscriptionOptions:
language: Optional[str] = None
task: Task = Task.TRANSCRIBE
model: Model = Model.WHISPER_TINY
model: TranscriptionModel = TranscriptionModel()
word_level_timings: bool = False
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
initial_prompt: str = ''
@ -373,7 +353,7 @@ class WhisperFileTranscriber(QObject):
current_process: multiprocessing.Process
progress = pyqtSignal(tuple) # (current, total)
completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
completed = pyqtSignal(list) # List[Segment]
error = pyqtSignal(str)
running = False
read_line_thread: Optional[Thread] = None
@ -390,6 +370,8 @@ class WhisperFileTranscriber(QObject):
self.temperature = task.transcription_options.temperature
self.initial_prompt = task.transcription_options.initial_prompt
self.model_path = task.model_path
self.transcription_options = task.transcription_options
self.transcription_task = task
self.segments = []
@pyqtSlot()
@ -406,13 +388,8 @@ class WhisperFileTranscriber(QObject):
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
self.current_process = multiprocessing.Process(
target=transcribe_whisper,
args=(
send_pipe, model_path, self.file_path,
self.language, self.task, self.word_level_timings,
self.temperature, self.initial_prompt
))
self.current_process = multiprocessing.Process(target=transcribe_whisper,
args=(send_pipe, self.transcription_task))
self.current_process.start()
self.read_line_thread = Thread(
@ -430,8 +407,9 @@ class WhisperFileTranscriber(QObject):
self.read_line_thread.join()
if self.current_process.exitcode != 0:
self.completed.emit((self.current_process.exitcode, []))
# TODO: fix error handling when process crashes
if self.current_process.exitcode != 0 and self.current_process.exitcode is not None:
self.completed.emit([])
self.running = False
@ -458,7 +436,7 @@ class WhisperFileTranscriber(QObject):
) for segment in segments_dict]
self.current_process.join()
# TODO: move this back to the parent thread
self.completed.emit((self.current_process.exitcode, segments))
self.completed.emit(segments)
else:
try:
progress = int(line.split('|')[0].strip().strip('%'))
@ -467,26 +445,30 @@ class WhisperFileTranscriber(QObject):
continue
def transcribe_whisper(
stderr_conn: Connection, model_path: str, file_path: str,
language: Optional[str], task: Task,
word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
def transcribe_whisper(stderr_conn: Connection, task: FileTranscriptionTask):
with pipe_stderr(stderr_conn):
model = whisper.load_model(model_path)
if word_level_timings:
stable_whisper.modify_model(model)
result = model.transcribe(
audio=file_path, language=language,
task=task.value, temperature=temperature,
initial_prompt=initial_prompt, pbar=True)
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
model = transformers_whisper.load_model(task.model_path)
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
result = model.transcribe(audio_path=task.file_path, language=language,
task=task.transcription_options.task.value, verbose=False)
whisper_segments = result.get('segments')
else:
result = model.transcribe(
audio=file_path, language=language, task=task.value, temperature=temperature,
initial_prompt=initial_prompt, verbose=False)
whisper_segments = stable_whisper.group_word_timestamps(
result) if word_level_timings else result.get('segments')
model = whisper.load_model(task.model_path)
if task.transcription_options.word_level_timings:
stable_whisper.modify_model(model)
result = model.transcribe(
audio=task.file_path, language=task.transcription_options.language,
task=task.transcription_options.task.value, temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
whisper_segments = stable_whisper.group_word_timestamps(result)
else:
result = model.transcribe(
audio=task.file_path, language=task.transcription_options.language,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt, verbose=False)
whisper_segments = result.get('segments')
segments = [
Segment(
@ -638,7 +620,7 @@ class FileTranscriberQueueWorker(QObject):
self.completed.emit()
return
if self.current_task.transcription_options.model.is_whisper_cpp():
if self.current_task.transcription_options.model.model_type == ModelType.WHISPER_CPP:
self.current_transcriber = WhisperCppFileTranscriber(
task=self.current_task)
else:
@ -688,10 +670,9 @@ class FileTranscriberQueueWorker(QObject):
self.current_task.fraction_completed = progress[0] / progress[1]
self.task_updated.emit(self.current_task)
@pyqtSlot(tuple)
def on_task_completed(self, result: Tuple[int, List[Segment]]):
@pyqtSlot(list)
def on_task_completed(self, segments: List[Segment]):
if self.current_task is not None:
_, segments = result
self.current_task.status = FileTranscriptionTask.Status.COMPLETED
self.current_task.segments = segments
self.task_updated.emit(self.current_task)
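
transcribe_whisper runs in a separate process; the parent reads the child's piped stderr and parses whisper's tqdm-style 'NN%|...' lines for progress, as in the int(line.split('|')[0]...) parsing above. A minimal sketch of that pattern with a plain Pipe (the worker body is illustrative; the real code forwards actual stderr via pipe_stderr):

    import multiprocessing
    from multiprocessing.connection import Connection

    def worker(conn: Connection):
        for pct in (0, 50, 100):
            conn.send(f'{pct}%|')  # mimic tqdm's progress-line prefix
        conn.close()

    if __name__ == '__main__':
        recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
        process = multiprocessing.Process(target=worker, args=(send_pipe,))
        process.start()
        send_pipe.close()  # parent keeps only the read end
        try:
            while True:
                line = recv_pipe.recv()
                progress = int(line.split('|')[0].strip().strip('%'))
                print('progress:', progress)
        except EOFError:  # raised once the child closes its end
            pass
        process.join()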

View file

@ -0,0 +1,56 @@
from typing import Optional
import numpy as np
import whisper
from tqdm import tqdm
from transformers import WhisperProcessor, WhisperForConditionalGeneration
def load_model(model_name_or_path: str):
processor = WhisperProcessor.from_pretrained(model_name_or_path)
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)
return TransformersWhisper(processor, model)
class TransformersWhisper:
SAMPLE_RATE = whisper.audio.SAMPLE_RATE
N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
def __init__(self, processor: WhisperProcessor, model: WhisperForConditionalGeneration):
self.processor = processor
self.model = model
# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
# https://github.com/huggingface/transformers/pull/20620.
def transcribe(self, audio_path: str, language: str, task: str, verbose: Optional[bool] = None):
audio: np.ndarray = whisper.load_audio(audio_path, sr=self.SAMPLE_RATE)
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(task=task, language=language)
segments = []
all_predicted_ids = []
num_samples = audio.size
seek = 0
with tqdm(total=num_samples, unit='samples', disable=verbose is not False) as progress_bar:
while seek < num_samples:
chunk = audio[seek: seek + self.N_SAMPLES_IN_CHUNK]
input_features = self.processor(chunk, return_tensors="pt",
sampling_rate=self.SAMPLE_RATE).input_features
predicted_ids = self.model.generate(input_features)
all_predicted_ids.extend(predicted_ids)
text: str = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
if text.strip() != '':
segments.append({
'start': seek / self.SAMPLE_RATE,
'end': min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) / self.SAMPLE_RATE,
'text': text
})
progress_bar.update(min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek)
seek += self.N_SAMPLES_IN_CHUNK
return {
'text': self.processor.batch_decode(all_predicted_ids, skip_special_tokens=True)[0],
'segments': segments
}
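
The new buzz/transformers_whisper.py module wraps transformers' WhisperProcessor and WhisperForConditionalGeneration behind a whisper-like transcribe() call, chunking audio into 30-second windows (whisper.audio.N_SAMPLES) because long-form transcription was not yet available upstream. A short usage sketch (the audio path is illustrative):

    from buzz import transformers_whisper

    # Loads the processor and weights from the local cache, downloading on first use
    model = transformers_whisper.load_model('openai/whisper-tiny')

    # verbose=False shows the tqdm progress bar, mirroring whisper's convention
    result = model.transcribe(audio_path='recording.mp3', language='en',
                              task='transcribe', verbose=False)
    print(result['text'])
    for segment in result['segments']:
        print(segment['start'], segment['end'], segment['text'])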

View file

@ -6,19 +6,19 @@ from unittest.mock import Mock, patch
import pytest
import sounddevice
from PyQt6.QtCore import QSize, Qt
from PyQt6.QtGui import QValidator
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget
from PyQt6.QtGui import QValidator, QKeyEvent
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication
from pytestqt.qtbot import QtBot
from buzz.cache import TasksCache
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
AudioDevicesComboBox, DownloadModelProgressDialog,
FileTranscriberWidget, LanguagesComboBox, MainWindow,
ModelComboBox, RecordingTranscriberWidget,
RecordingTranscriberWidget,
TemperatureValidator, TextDisplayBox,
TranscriptionTasksTableWidget, TranscriptionViewerWidget)
TranscriptionTasksTableWidget, TranscriptionViewerWidget, HuggingFaceSearchLineEdit)
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
Model, Segment, TranscriptionOptions)
Segment, TranscriptionOptions)
class TestApplication:
@ -48,20 +48,6 @@ class TestLanguagesComboBox:
assert languages_combo_box.currentText() == 'Detect Language'
class TestModelComboBox:
model_combo_box = ModelComboBox(
default_model=Model.WHISPER_CPP_BASE, parent=None)
def test_should_show_qualities(self):
assert self.model_combo_box.itemText(0) == 'Whisper - Tiny'
assert self.model_combo_box.itemText(1) == 'Whisper - Base'
assert self.model_combo_box.itemText(2) == 'Whisper - Small'
assert self.model_combo_box.itemText(3) == 'Whisper - Medium'
def test_should_select_default_model(self):
assert self.model_combo_box.currentText() == 'Whisper.cpp - Base'
class TestAudioDevicesComboBox:
mock_query_devices = [
{'name': 'Background Music', 'index': 0, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
@ -198,10 +184,9 @@ class TestFileTranscriberWidget:
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], parent=None)
def test_should_set_window_title_and_size(self, qtbot: QtBot):
def test_should_set_window_title(self, qtbot: QtBot):
qtbot.addWidget(self.widget)
assert self.widget.windowTitle() == 'whisper-french.mp3'
assert self.widget.size() == QSize(420, 270)
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
@ -217,7 +202,6 @@ class TestFileTranscriberWidget:
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
0][0]
assert transcription_options.language is None
assert transcription_options.model == Model.WHISPER_TINY
assert file_transcription_options.file_paths == [
'testdata/whisper-french.mp3']
assert len(model_path) > 0
@ -232,8 +216,7 @@ class TestAboutDialog:
class TestAdvancedSettingsDialog:
def test_should_update_advanced_settings(self, qtbot: QtBot):
dialog = AdvancedSettingsDialog(
transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt',
model=Model.WHISPER_CPP_BASE))
transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt'))
qtbot.add_widget(dialog)
transcription_options_mock = Mock()
@ -333,7 +316,50 @@ class TestTranscriptionTasksTableWidget:
class TestRecordingTranscriberWidget:
widget = RecordingTranscriberWidget()
def test_should_set_window_title_and_size(self, qtbot: QtBot):
def test_should_set_window_title(self, qtbot: QtBot):
qtbot.add_widget(self.widget)
assert self.widget.windowTitle() == 'Live Recording'
assert self.widget.size() == QSize(400, 520)
class TestHuggingFaceSearchLineEdit:
def test_should_update_selected_model_on_type(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit()
qtbot.add_widget(widget)
mock_model_selected = Mock()
widget.model_selected.connect(mock_model_selected)
self._set_text_and_wait_response(qtbot, widget)
mock_model_selected.assert_called_with('openai/whisper-tiny')
def test_should_show_list_of_models(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit()
qtbot.add_widget(widget)
self._set_text_and_wait_response(qtbot, widget)
assert widget.popup.count() > 0
assert 'openai/whisper-tiny' in widget.popup.item(0).text()
def test_should_select_model_from_list(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit()
qtbot.add_widget(widget)
mock_model_selected = Mock()
widget.model_selected.connect(mock_model_selected)
self._set_text_and_wait_response(qtbot, widget)
# press down arrow and enter to select next item
QApplication.sendEvent(widget.popup,
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier))
QApplication.sendEvent(widget.popup,
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Enter, Qt.KeyboardModifier.NoModifier))
mock_model_selected.assert_called_with('openai/whisper-tiny.en')
@staticmethod
def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit):
with qtbot.wait_signal(widget.network_manager.finished, timeout=30 * 1000):
widget.setText('openai/whisper-tiny')
widget.textEdited.emit('openai/whisper-tiny')

tests/model_loader.py Normal file
View file

@ -0,0 +1,14 @@
from buzz.model_loader import ModelLoader, TranscriptionModel
def get_model_path(transcription_model: TranscriptionModel) -> str:
model_loader = ModelLoader(model=transcription_model)
model_path = ''
def on_load_model(path: str):
nonlocal model_path
model_path = path
model_loader.finished.connect(on_load_model)
model_loader.run()
return model_path

View file

@ -8,30 +8,18 @@ from unittest.mock import Mock
import pytest
from pytestqt.qtbot import QtBot
from buzz.model_loader import ModelLoader
from buzz.transcriber import (FileTranscriberQueueWorker, FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber, Segment, Task,
WhisperCpp, WhisperCppFileTranscriber,
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber,
get_default_output_file_path, to_timestamp,
whisper_cpp_params, write_output, TranscriptionOptions, Model)
def get_model_path(model: Model) -> str:
model_loader = ModelLoader(model=model)
model_path = ''
def on_load_model(path: str):
nonlocal model_path
model_path = path
model_loader.finished.connect(on_load_model)
model_loader.run()
return model_path
whisper_cpp_params, write_output, TranscriptionOptions)
from tests.model_loader import get_model_path
class TestRecordingTranscriber:
def test_transcriber(self):
model_path = get_model_path(Model.WHISPER_CPP_TINY)
model_path = get_model_path(transcription_model=TranscriptionModel())
transcriber = RecordingTranscriber(
model_path=model_path, use_whisper_cpp=True, language='en',
task=Task.TRANSCRIBE)
@ -49,11 +37,15 @@ class TestWhisperCppFileTranscriber:
file_transcription_options = FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3'])
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
word_level_timings=word_level_timings)
word_level_timings=word_level_timings,
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY))
model_path = get_model_path(Model.WHISPER_CPP_TINY)
model_path = get_model_path(transcription_options.model)
transcriber = WhisperCppFileTranscriber(
task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3', transcription_options=transcription_options, file_transcription_options=file_transcription_options, model_path=model_path))
task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3',
transcription_options=transcription_options,
file_transcription_options=file_transcription_options, model_path=model_path))
mock_progress = Mock()
mock_completed = Mock()
transcriber.progress.connect(mock_progress)
@ -81,44 +73,51 @@ class TestWhisperFileTranscriber:
assert srt.endswith('.srt')
@pytest.mark.parametrize(
'word_level_timings,expected_segments',
'word_level_timings,expected_segments,model,check_progress',
[
(False, [
Segment(
0, 6560,
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'),
]),
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')])
(False, [Segment(0, 6560,
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')],
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')],
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
(False, [Segment(0, 8517,
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances '
'et des apprenances de français.')],
TranscriptionModel(model_type=ModelType.HUGGING_FACE,
hugging_face_model_id='openai/whisper-tiny'), False)
])
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
model_path = get_model_path(Model.WHISPER_TINY)
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment],
model: TranscriptionModel, check_progress):
mock_progress = Mock()
mock_completed = Mock()
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
word_level_timings=word_level_timings)
word_level_timings=word_level_timings,
model=model)
model_path = get_model_path(transcription_options.model)
file_transcription_options = FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3'])
transcriber = WhisperFileTranscriber(
task=FileTranscriptionTask(transcription_options=transcription_options, file_transcription_options=file_transcription_options, file_path='testdata/whisper-french.mp3', model_path=model_path))
task=FileTranscriptionTask(transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
file_path='testdata/whisper-french.mp3', model_path=model_path))
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
transcriber.run()
# Reports progress at 0, 0<progress<100, and 100
assert any(
[call_args.args[0] == (0, 100) for call_args in mock_progress.call_args_list])
assert any(
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
assert any(
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
mock_progress.call_args_list])
if check_progress:
# Reports progress at 0, 0<progress<100, and 100
assert any(
[call_args.args[0] == (0, 100) for call_args in mock_progress.call_args_list])
assert any(
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
assert any(
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
mock_progress.call_args_list])
mock_completed.assert_called()
exit_code, segments = mock_completed.call_args[0][0]
assert exit_code is 0
segments = mock_completed.call_args[0][0]
for (i, expected_segment) in enumerate(expected_segments):
assert segments[i] == expected_segment
@ -128,14 +127,17 @@ class TestWhisperFileTranscriber:
if os.path.exists(output_file_path):
os.remove(output_file_path)
model_path = get_model_path(Model.WHISPER_TINY)
file_transcription_options = FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3'])
transcription_options = TranscriptionOptions(
language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
language='fr', task=Task.TRANSCRIBE, word_level_timings=False,
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
model_path = get_model_path(transcription_options.model)
transcriber = WhisperFileTranscriber(
task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options, file_transcription_options=file_transcription_options, file_path='testdata/whisper-french.mp3'))
task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
file_path='testdata/whisper-french.mp3'))
transcriber.run()
time.sleep(1)
transcriber.stop()
@ -152,7 +154,9 @@ class TestToTimestamp:
class TestWhisperCpp:
def test_transcribe(self):
model_path = get_model_path(Model.WHISPER_CPP_TINY)
transcription_options = TranscriptionOptions(
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
model_path = get_model_path(transcription_options.model)
whisper_cpp = WhisperCpp(model=model_path)
params = whisper_cpp_params(
@ -168,8 +172,8 @@ class TestWhisperCpp:
[
(OutputFormat.TXT, 'Bien venue dans\n'),
(
OutputFormat.SRT,
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
OutputFormat.SRT,
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
(OutputFormat.VTT,
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
])

View file

@ -0,0 +1,10 @@
from buzz.transformers_whisper import load_model
class TestTransformersWhisper:
def test_should_transcribe(self):
model = load_model('openai/whisper-tiny')
result = model.transcribe(
audio_path='testdata/whisper-french.mp3', language='fr', task='transcribe')
assert 'Bienvenue dans Passe' in result['text']