Add models preferences (#421)

Chidi Williams 2023-04-28 21:28:05 +00:00 committed by GitHub
parent 0a4be2b195
commit cb5ad74620
15 changed files with 582 additions and 246 deletions


@@ -12,16 +12,15 @@ import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QPoint,
QUrlQuery, QMetaObject, QEvent)
QUrlQuery, QMetaObject, QEvent, QThreadPool)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QDialogButtonBox, QFileDialog, QLabel, QLineEdit,
QMainWindow, QMessageBox, QPlainTextEdit,
QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox, QTableWidget,
QMenuBar, QFormLayout, QTableWidgetItem,
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
from whisper import tokenizer
from buzz.cache import TasksCache
@@ -30,7 +29,8 @@ from .action import Action
from .assets import get_asset_path
from .icon import Icon
from .locale import _
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, get_local_model_path, \
ModelDownloader
from .paths import file_paths_as_title
from .recording import RecordingAmplitudeListener
from .settings.settings import Settings, APP_NAME
@@ -43,6 +43,8 @@ from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Ou
FileTranscriberQueueWorker, FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE)
from .widgets.line_edit import LineEdit
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from .widgets.model_type_combo_box import ModelTypeComboBox
from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
from .widgets.preferences_dialog import PreferencesDialog
from .widgets.toolbar import ToolBar
@@ -163,33 +165,6 @@ class RecordButton(QPushButton):
self.setDefault(False)
class DownloadModelProgressDialog(QProgressDialog):
start_time: datetime
def __init__(self, parent: Optional[QWidget], *args) -> None:
super().__init__(_('Downloading model (0%, unknown time remaining)'),
_('Cancel'), 0, 100, parent, *args)
# Setting this to a high value to avoid showing the dialog for models that
# are checked locally but set progress to 0 immediately, i.e. Hugging Face or Faster Whisper models
self.setMinimumDuration(10_000)
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.start_time = datetime.now()
self.setFixedSize(self.size())
def set_fraction_completed(self, fraction_completed: float) -> None:
self.setValue(int(fraction_completed * self.maximum()))
if fraction_completed > 0.0:
time_spent = (datetime.now() - self.start_time).total_seconds()
time_left = (time_spent / fraction_completed) - time_spent
self.setLabelText(_('Downloading model') +
f' ({fraction_completed :.0%}, ' +
humanize.naturaldelta(time_left) + ')')
def show_model_download_error_dialog(parent: QWidget, error: str):
message = parent.tr(
'An error occurred while loading the Whisper model') + \
@@ -200,9 +175,8 @@ def show_model_download_error_dialog(parent: QWidget, error: str):
class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
model_loader: Optional[ModelLoader] = None
transcriber_thread: Optional[QThread] = None
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
model_loader: Optional[ModelDownloader] = None
file_transcription_options: FileTranscriptionOptions
transcription_options: TranscriptionOptions
is_transcribing = False
@@ -291,23 +265,16 @@ class FileTranscriberWidget(QWidget):
def on_click_run(self):
self.run_button.setDisabled(True)
self.transcriber_thread = QThread()
self.model_loader = ModelLoader(model=self.transcription_options.model)
self.model_loader.moveToThread(self.transcriber_thread)
model_path = get_local_model_path(model=self.transcription_options.model)
if model_path is not None:
self.on_model_loaded(model_path)
return
self.transcriber_thread.started.connect(self.model_loader.run)
self.model_loader.finished.connect(
self.transcriber_thread.quit)
self.model_loader.progress.connect(self.on_download_model_progress)
self.model_loader.error.connect(self.on_download_model_error)
self.model_loader.error.connect(self.model_loader.deleteLater)
self.model_loader.finished.connect(self.on_model_loaded)
self.model_loader.finished.connect(self.model_loader.deleteLater)
self.transcriber_thread.start()
self.model_loader = ModelDownloader(model=self.transcription_options.model)
self.model_loader.signals.progress.connect(self.on_download_model_progress)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.model_loader.signals.finished.connect(self.on_model_loaded)
QThreadPool().globalInstance().start(self.model_loader)
def on_model_loaded(self, model_path: str):
self.reset_transcriber_controls()
@@ -320,12 +287,13 @@
(current_size, total_size) = progress
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = DownloadModelProgressDialog(parent=self)
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
def on_download_model_error(self, error: str):
self.reset_model_download()
@@ -337,7 +305,7 @@
def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.stop()
self.model_loader.cancel()
self.reset_model_download()
def reset_model_download(self):
@@ -349,8 +317,8 @@
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
if self.transcriber_thread is not None:
self.transcriber_thread.wait()
if self.model_loader is not None:
self.model_loader.cancel()
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task)
@@ -432,9 +400,9 @@ class RecordingTranscriberWidget(QWidget):
current_status: 'RecordingStatus'
transcription_options: TranscriptionOptions
selected_device_id: Optional[int]
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
transcriber: Optional[RecordingTranscriber] = None
model_loader: Optional[ModelLoader] = None
model_loader: Optional[ModelDownloader] = None
transcription_thread: Optional[QThread] = None
recording_amplitude_listener: Optional[RecordingAmplitudeListener] = None
device_sample_rate: Optional[int] = None
@@ -541,29 +509,35 @@
def start_recording(self):
self.record_button.setDisabled(True)
model_path = get_local_model_path(model=self.transcription_options.model)
if model_path is not None:
self.on_model_loaded(model_path)
return
self.model_loader = ModelDownloader(model=self.transcription_options.model)
self.model_loader.signals.progress.connect(self.on_download_model_progress)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.model_loader.signals.finished.connect(self.on_model_loaded)
QThreadPool().globalInstance().start(self.model_loader)
def on_model_loaded(self, model_path: str):
self.reset_recording_controls()
self.model_loader = None
self.transcription_thread = QThread()
self.model_loader = ModelLoader(model=self.transcription_options.model)
# TODO: make runnable
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
sample_rate=self.device_sample_rate,
transcription_options=self.transcription_options)
transcription_options=self.transcription_options,
model_path=model_path)
self.model_loader.moveToThread(self.transcription_thread)
self.transcriber.moveToThread(self.transcription_thread)
self.transcription_thread.started.connect(self.model_loader.run)
self.transcription_thread.started.connect(self.transcriber.start)
self.transcription_thread.finished.connect(
self.transcription_thread.deleteLater)
self.model_loader.finished.connect(self.reset_recording_controls)
self.model_loader.finished.connect(self.transcriber.start)
self.model_loader.finished.connect(self.model_loader.deleteLater)
self.model_loader.progress.connect(
self.on_download_model_progress)
self.model_loader.error.connect(self.on_download_model_error)
self.transcriber.transcription.connect(self.on_next_transcription)
self.transcriber.finished.connect(self.on_transcriber_finished)
@@ -580,12 +554,13 @@
(current_size, total_size) = progress
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = DownloadModelProgressDialog(parent=self)
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
def set_recording_status_stopped(self):
self.record_button.set_stopped()
@@ -640,9 +615,7 @@
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.reset_record_button()
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
self.reset_model_download()
def reset_record_button(self):
self.record_button.setEnabled(True)
@@ -651,6 +624,9 @@
self.audio_meter_widget.update_amplitude(amplitude)
def closeEvent(self, event: QCloseEvent) -> None:
if self.model_loader is not None:
self.model_loader.cancel()
self.stop_recording()
if self.recording_amplitude_listener is not None:
self.recording_amplitude_listener.stop_recording()
@@ -1264,17 +1240,10 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
self.model_type_combo_box = QComboBox(self)
if model_types is None:
model_types = [model_type for model_type in ModelType]
for model_type in model_types:
# Hide the Whisper.cpp option if whisper.dll did not load correctly.
# See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
if model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False:
continue
self.model_type_combo_box.addItem(model_type.value)
self.model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
self.model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
default_model=default_transcription_options.model.model_type,
parent=self)
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
@@ -1339,8 +1308,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
model_type == ModelType.FASTER_WHISPER))
self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API)
def on_model_type_changed(self, text: str):
model_type = ModelType(text)
def on_model_type_changed(self, model_type: ModelType):
self.transcription_options.model.model_type = model_type
self.reset_visible_rows()
self.transcription_options_changed.emit(self.transcription_options)
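
The hunks above replace the QThread-hosted ModelLoader with a ModelDownloader started on the global thread pool. A minimal sketch of that QRunnable-plus-signals pattern (not code from this commit; Worker is a hypothetical stand-in for ModelDownloader):

import sys
from PyQt6.QtCore import QCoreApplication, QObject, QRunnable, QThreadPool, pyqtSignal

class Worker(QRunnable):
    class Signals(QObject):
        # QRunnable is not a QObject, so the signals live on a nested helper object
        progress = pyqtSignal(tuple)  # (current, total)
        finished = pyqtSignal(str)

    def __init__(self) -> None:
        super().__init__()
        self.signals = self.Signals()

    def run(self) -> None:
        self.signals.progress.emit((50, 100))
        self.signals.finished.emit('/path/to/model.bin')

app = QCoreApplication(sys.argv)
worker = Worker()
worker.signals.progress.connect(lambda progress: print('progress:', progress))
worker.signals.finished.connect(lambda path: (print('finished:', path), app.quit()))
QThreadPool.globalInstance().start(worker)
sys.exit(app.exec())

The commit spells the pool lookup as QThreadPool().globalInstance(); since globalInstance() is a static method, that resolves to the same shared pool as QThreadPool.globalInstance().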


@@ -2,17 +2,18 @@ import enum
import hashlib
import logging
import os
import tempfile
import warnings
from dataclasses import dataclass
from typing import Optional
import faster_whisper
import huggingface_hub
import requests
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
from platformdirs import user_cache_dir
from buzz import transformers_whisper
from tqdm.auto import tqdm
class WhisperModelSize(enum.Enum):
@@ -51,121 +52,185 @@ def get_hugging_face_dataset_file_url(author: str, repository_name: str, filenam
return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}'
class ModelLoader(QObject):
progress = pyqtSignal(tuple) # (current, total)
finished = pyqtSignal(str)
error = pyqtSignal(str)
stopped = False
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
root_dir = user_cache_dir('Buzz')
return os.path.join(root_dir, f'ggml-model-whisper-{size.value}.bin')
def __init__(self, model: TranscriptionModel, parent: Optional['QObject'] = None) -> None:
super().__init__(parent)
self.model_type = model.model_type
self.whisper_model_size = model.whisper_model_size
self.hugging_face_model_id = model.hugging_face_model_id
@pyqtSlot()
def run(self):
if self.model_type == ModelType.WHISPER_CPP:
root_dir = user_cache_dir('Buzz')
model_name = self.whisper_model_size.value
def get_whisper_file_path(size: WhisperModelSize) -> str:
root_dir = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
url = whisper._MODELS[size.value]
return os.path.join(root_dir, os.path.basename(url))
def get_local_model_path(model: TranscriptionModel) -> Optional[str]:
if model.model_type == ModelType.WHISPER_CPP:
file_path = get_whisper_cpp_file_path(size=model.whisper_model_size)
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return None
return file_path
if model.model_type == ModelType.WHISPER:
file_path = get_whisper_file_path(size=model.whisper_model_size)
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return None
return file_path
if model.model_type == ModelType.FASTER_WHISPER:
try:
return download_faster_whisper_model(size=model.whisper_model_size.value, local_files_only=True)
except (ValueError, FileNotFoundError):
return None
if model.model_type == ModelType.OPEN_AI_WHISPER_API:
return ''
if model.model_type == ModelType.HUGGING_FACE:
try:
return huggingface_hub.snapshot_download(model.hugging_face_model_id, local_files_only=True)
except (ValueError, FileNotFoundError):
return None
raise Exception("Unknown model type")
def download_faster_whisper_model(size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None):
if size not in faster_whisper.utils._MODELS:
raise ValueError(
"Invalid model size '%s', expected one of: %s" % (size, ", ".join(faster_whisper.utils._MODELS))
)
repo_id = "guillaumekln/faster-whisper-%s" % size
allow_patterns = [
"config.json",
"model.bin",
"tokenizer.json",
"vocabulary.txt",
]
return huggingface_hub.snapshot_download(repo_id, allow_patterns=allow_patterns, local_files_only=local_files_only,
tqdm_class=tqdm_class)
class ModelDownloader(QRunnable):
class Signals(QObject):
finished = pyqtSignal(str)
progress = pyqtSignal(tuple) # (current, total)
error = pyqtSignal(str)
def __init__(self, model: TranscriptionModel):
super().__init__()
self.signals = self.Signals()
self.model = model
self.stopped = False
def run(self) -> None:
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin')
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
self.download_model(url, file_path, expected_sha256)
return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256)
elif self.model_type == ModelType.WHISPER:
root_dir = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
model_name = self.whisper_model_size.value
url = whisper._MODELS[model_name]
file_path = os.path.join(root_dir, os.path.basename(url))
if self.model.model_type == ModelType.WHISPER:
url = whisper._MODELS[self.model.whisper_model_size.value]
file_path = get_whisper_file_path(size=self.model.whisper_model_size)
expected_sha256 = url.split('/')[-2]
self.download_model(url, file_path, expected_sha256)
return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256)
elif self.model_type == ModelType.HUGGING_FACE:
self.progress.emit((0, 100))
progress = self.signals.progress
try:
# Loads the model from cache or download if not in cache
transformers_whisper.load_model(self.hugging_face_model_id)
except (FileNotFoundError, EnvironmentError) as exception:
self.error.emit(f'{exception}')
return
# gross abuse of power...
class _tqdm(tqdm):
def update(self, n: float | None = ...) -> bool | None:
progress.emit((n, self.total))
return super().update(n)
self.progress.emit((100, 100))
file_path = self.hugging_face_model_id
def close(self):
progress.emit((self.n, self.total))
return super().close()
elif self.model_type == ModelType.OPEN_AI_WHISPER_API:
file_path = ""
if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(size=self.model.whisper_model_size.value, tqdm_class=_tqdm)
self.signals.finished.emit(model_path)
return
elif self.model_type == ModelType.FASTER_WHISPER:
self.progress.emit((0, 100))
file_path = faster_whisper.download_model(size=self.whisper_model_size.value)
self.progress.emit((100, 100))
if self.model.model_type == ModelType.HUGGING_FACE:
model_path = huggingface_hub.snapshot_download(self.model.hugging_face_model_id, tqdm_class=_tqdm)
self.signals.finished.emit(model_path)
return
else:
raise Exception("Invalid model type: " + self.model_type.value)
if self.model.model_type == ModelType.OPEN_AI_WHISPER_API:
self.signals.finished.emit('')
return
self.finished.emit(file_path)
raise Exception("Invalid model type: " + self.model.model_type.value)
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
def download_model_to_path(self, url: str, file_path: str, expected_sha256: Optional[str]):
try:
logging.debug(f'Downloading model from {url} to {file_path}')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
if os.path.exists(file_path) and not os.path.isfile(file_path):
raise RuntimeError(
f"{file_path} exists and is not a regular file")
if os.path.isfile(file_path):
if expected_sha256 is None:
return file_path
model_bytes = open(file_path, "rb").read()
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
if model_sha256 == expected_sha256:
return file_path
else:
warnings.warn(
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output:
source.raise_for_status()
total_size = float(source.headers.get('Content-Length', 0))
current = 0.0
self.progress.emit((current, total_size))
for chunk in source.iter_content(chunk_size=8192):
if self.stopped:
return
output.write(chunk)
current += len(chunk)
self.progress.emit((current, total_size))
if expected_sha256 is not None:
model_bytes = open(file_path, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
"model.")
logging.debug('Downloaded model')
return file_path
except RuntimeError as exc:
self.error.emit(str(exc))
logging.exception('')
downloaded = self.download_model(url, file_path, expected_sha256)
if downloaded:
self.signals.finished.emit(file_path)
except requests.RequestException:
self.error.emit('A connection error occurred')
logging.exception('')
except Exception:
self.error.emit('An unknown error occurred')
self.signals.error.emit('A connection error occurred')
logging.exception('')
except Exception as exc:
self.signals.error.emit(str(exc))
logging.exception(exc)
def stop(self):
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]) -> bool:
logging.debug(f'Downloading model from {url} to {file_path}')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
if os.path.exists(file_path) and not os.path.isfile(file_path):
raise RuntimeError(
f"{file_path} exists and is not a regular file")
if os.path.isfile(file_path):
if expected_sha256 is None:
return True
model_bytes = open(file_path, "rb").read()
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
if model_sha256 == expected_sha256:
return True
else:
warnings.warn(
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
tmp_file = tempfile.mktemp()
logging.debug('Downloading to temporary file = %s', tmp_file)
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
with requests.get(url, stream=True, timeout=15) as source, open(tmp_file, 'wb') as output:
source.raise_for_status()
total_size = float(source.headers.get('Content-Length', 0))
current = 0.0
self.signals.progress.emit((current, total_size))
for chunk in source.iter_content(chunk_size=8192):
if self.stopped:
return False
output.write(chunk)
current += len(chunk)
self.signals.progress.emit((current, total_size))
if expected_sha256 is not None:
model_bytes = open(tmp_file, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
"model.")
logging.debug('Downloaded model')
os.rename(tmp_file, file_path)
logging.debug('Moved file from %s to %s', tmp_file, file_path)
return True
def cancel(self):
self.stopped = True
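
For the Hugging Face and Faster Whisper branches above, progress comes from the nested tqdm subclass (_tqdm) that ModelDownloader passes to huggingface_hub.snapshot_download and download_faster_whisper_model through tqdm_class. A standalone sketch of that hook, with a plain callback standing in for signals.progress (hypothetical names, not code from this commit):

from tqdm.auto import tqdm

def make_reporting_tqdm(report):
    class ReportingTqdm(tqdm):
        def update(self, n=1):
            displayed = super().update(n)
            report((self.n, self.total))  # cumulative count and total
            return displayed
    return ReportingTqdm

bar_class = make_reporting_tqdm(lambda progress: print('progress:', progress))
bar = bar_class(total=100)
bar.update(40)   # progress: (40, 100)
bar.update(60)   # progress: (100, 100)
bar.close()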


@@ -107,20 +107,23 @@ class RecordingTranscriber(QObject):
MAX_QUEUE_SIZE = 10
def __init__(self, transcription_options: TranscriptionOptions,
input_device_index: Optional[int], sample_rate: int, parent: Optional[QObject] = None) -> None:
input_device_index: Optional[int], sample_rate: int, model_path: str,
parent: Optional[QObject] = None) -> None:
super().__init__(parent)
self.transcription_options = transcription_options
self.current_stream = None
self.input_device_index = input_device_index
self.sample_rate = sample_rate
self.model_path = model_path
self.n_batch_samples = 5 * self.sample_rate # every 5 seconds
# pause queueing if more than 3 batches behind
self.max_queue_size = 3 * self.n_batch_samples
self.queue = np.ndarray([], dtype=np.float32)
self.mutex = threading.Lock()
@pyqtSlot(str)
def start(self, model_path: str):
def start(self):
model_path = self.model_path
if self.transcription_options.model.model_type == ModelType.WHISPER:
model = whisper.load_model(model_path)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:


@@ -0,0 +1,46 @@
from datetime import datetime
from typing import Optional
import humanize
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QProgressDialog, QWidget, QPushButton
from buzz.locale import _
from buzz.model_loader import ModelType
class ModelDownloadProgressDialog(QProgressDialog):
def __init__(self, model_type: ModelType, parent: Optional[QWidget] = None, modality=Qt.WindowModality.WindowModal):
super().__init__(parent)
self.cancelable = model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
self.start_time = datetime.now()
self.setRange(0, 100)
self.setMinimumDuration(0)
self.setWindowModality(modality)
self.update_label_text(0)
if not self.cancelable:
cancel_button = QPushButton('Cancel', self)
cancel_button.setEnabled(False)
self.setCancelButton(cancel_button)
def update_label_text(self, fraction_completed: float):
label_text = f"{_('Downloading model')} ({fraction_completed:.0%}"
if fraction_completed > 0:
time_spent = (datetime.now() - self.start_time).total_seconds()
time_left = (time_spent / fraction_completed) - time_spent
label_text += f', {humanize.naturaldelta(time_left)} remaining'
label_text += ')'
self.setLabelText(label_text)
def set_value(self, fraction_completed: float):
if self.wasCanceled():
return
self.setValue(int(fraction_completed * self.maximum()))
self.update_label_text(fraction_completed)
def cancel(self) -> None:
if self.cancelable:
super().cancel()
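
update_label_text estimates the time remaining by assuming a constant download rate: if fraction_completed of the work took time_spent seconds, the whole download takes time_spent / fraction_completed seconds, and what is left is that total minus time_spent. A worked example with hypothetical numbers:

time_spent = 30.0            # hypothetical: 30 seconds elapsed so far
fraction_completed = 0.25    # hypothetical: 25% of the model downloaded
time_left = (time_spent / fraction_completed) - time_spent
print(time_left)             # 90.0 seconds, which humanize.naturaldelta turns into a friendly string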


@@ -0,0 +1,32 @@
from typing import Optional, List
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QComboBox, QWidget
from buzz.model_loader import ModelType
from buzz.transcriber import LOADED_WHISPER_DLL
class ModelTypeComboBox(QComboBox):
changed = pyqtSignal(ModelType)
def __init__(self, model_types: Optional[List[ModelType]] = None, default_model: Optional[ModelType] = None,
parent: Optional[QWidget] = None):
super().__init__(parent)
if model_types is None:
model_types = [model_type for model_type in ModelType]
for model_type in model_types:
# Hide the Whisper.cpp option if whisper.dll did not load correctly.
# See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
if model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False:
continue
self.addItem(model_type.value)
self.currentTextChanged.connect(self.on_text_changed)
if default_model is not None:
self.setCurrentText(default_model.value)
def on_text_changed(self, text: str):
self.changed.emit(ModelType(text))
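
A usage sketch for the new combo box (assumes a running QApplication; the print handler is hypothetical): changed delivers a ModelType member rather than display text, which is why on_model_type_changed in gui.py no longer calls ModelType(text) itself.

import sys
from PyQt6.QtWidgets import QApplication
from buzz.model_loader import ModelType
from buzz.widgets.model_type_combo_box import ModelTypeComboBox

app = QApplication(sys.argv)
combo_box = ModelTypeComboBox(default_model=ModelType.WHISPER)
combo_box.changed.connect(lambda model_type: print('selected:', model_type))
combo_box.setCurrentText(ModelType.HUGGING_FACE.value)  # emits changed with ModelType.HUGGING_FACE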


@@ -0,0 +1,123 @@
from typing import Optional
from PyQt6.QtCore import Qt, QThreadPool
from PyQt6.QtWidgets import QWidget, QFormLayout, QTreeWidget, QTreeWidgetItem, QPushButton, QMessageBox
from buzz.locale import _
from buzz.model_loader import ModelType, WhisperModelSize, get_local_model_path, TranscriptionModel, ModelDownloader
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
class ModelsPreferencesWidget(QWidget):
def __init__(self, progress_dialog_modality: Optional[Qt.WindowModality] = None, parent: Optional[QWidget] = None):
super().__init__(parent)
self.model_downloader: Optional[ModelDownloader] = None
self.model = TranscriptionModel(model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY)
self.progress_dialog_modality = progress_dialog_modality
self.progress_dialog: Optional[ModelDownloadProgressDialog] = None
layout = QFormLayout()
model_type_combo_box = ModelTypeComboBox(
model_types=[ModelType.WHISPER, ModelType.WHISPER_CPP, ModelType.FASTER_WHISPER],
default_model=self.model.model_type, parent=self)
model_type_combo_box.changed.connect(self.on_model_type_changed)
layout.addRow('Group', model_type_combo_box)
self.model_list_widget = QTreeWidget()
self.model_list_widget.setColumnCount(1)
self.reset_model_size_list()
self.model_list_widget.currentItemChanged.connect(self.on_model_size_changed)
layout.addWidget(self.model_list_widget)
self.download_button = QPushButton(_('Download'))
self.download_button.clicked.connect(self.on_download_button_clicked)
self.reset_download_button()
layout.addWidget(self.download_button)
self.setLayout(layout)
def on_model_size_changed(self, current: QTreeWidgetItem, _: QTreeWidgetItem):
if current is None:
return
item_data = current.data(0, Qt.ItemDataRole.UserRole)
if item_data is None:
return
self.model.whisper_model_size = item_data
self.reset_download_button()
def reset_download_button(self):
model_path = get_local_model_path(model=self.model)
self.download_button.setEnabled(model_path is None)
def on_model_type_changed(self, model_type: ModelType):
self.model.model_type = model_type
self.reset_model_size_list()
self.reset_download_button()
def on_download_button_clicked(self):
self.progress_dialog = ModelDownloadProgressDialog(model_type=self.model.model_type,
modality=self.progress_dialog_modality, parent=self)
self.progress_dialog.canceled.connect(self.on_progress_dialog_canceled)
self.download_button.setEnabled(False)
self.model_downloader = ModelDownloader(model=self.model)
self.model_downloader.signals.finished.connect(self.on_download_completed)
self.model_downloader.signals.progress.connect(self.on_download_progress)
self.model_downloader.signals.error.connect(self.on_download_error)
QThreadPool().globalInstance().start(self.model_downloader)
def on_download_completed(self, _: str):
self.progress_dialog = None
self.reset_download_button()
self.reset_model_size_list()
def on_download_error(self, error: str):
self.progress_dialog.cancel()
self.progress_dialog = None
self.reset_download_button()
self.reset_model_size_list()
QMessageBox.warning(self, _('Error'), f'Download failed: {error}')
def on_download_progress(self, progress: tuple):
self.progress_dialog.set_value(float(progress[0]) / progress[1])
def reset_model_size_list(self):
self.model_list_widget.clear()
downloaded_item = QTreeWidgetItem(self.model_list_widget)
downloaded_item.setText(0, _('Downloaded'))
downloaded_item.setFlags(
downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
available_item = QTreeWidgetItem(self.model_list_widget)
available_item.setText(0, _('Available for Download'))
available_item.setFlags(
available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
self.model_list_widget.addTopLevelItems([downloaded_item, available_item])
self.model_list_widget.expandToDepth(2)
self.model_list_widget.setHeaderHidden(True)
self.model_list_widget.setAlternatingRowColors(True)
for model_size in WhisperModelSize:
model_path = get_local_model_path(
model=TranscriptionModel(model_type=self.model.model_type, whisper_model_size=model_size))
parent = downloaded_item if model_path is not None else available_item
item = QTreeWidgetItem(parent)
item.setText(0, model_size.value.title())
item.setData(0, Qt.ItemDataRole.UserRole, model_size)
if self.model.whisper_model_size == model_size:
item.setSelected(True)
parent.addChild(item)
def on_progress_dialog_canceled(self):
self.model_downloader.cancel()
self.reset_model_size_list()
self.reset_download_button()


@@ -4,8 +4,8 @@ from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QTabWidget, QDialogButtonBox
from buzz.locale import _
from buzz.store.keyring_store import KeyringStore
from buzz.widgets.general_preferences_widget import GeneralPreferencesWidget
from buzz.widgets.models_preferences_widget import ModelsPreferencesWidget
from buzz.widgets.shortcuts_editor_preferences_widget import ShortcutsEditorPreferencesWidget
@@ -25,6 +25,9 @@ class PreferencesDialog(QDialog):
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
tab_widget.addTab(general_tab_widget, _('General'))
models_tab_widget = ModelsPreferencesWidget(parent=self)
tab_widget.addTab(models_tab_widget, _('Models'))
shortcuts_table_widget = ShortcutsEditorPreferencesWidget(shortcuts, self)
shortcuts_table_widget.shortcuts_changed.connect(self.shortcuts_changed)
tab_widget.addTab(shortcuts_table_widget, _('Shortcuts'))

main.py

@@ -38,10 +38,15 @@ if __name__ == "__main__":
log_dir = user_log_dir(appname='Buzz')
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
filename=os.path.join(log_dir, 'logs.txt'),
level=logging.DEBUG,
format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s")
log_format = "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.DEBUG, format=log_format)
if getattr(sys, 'frozen', False) is False:
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(stdout_handler)
from buzz.gui import Application


@@ -1,4 +1,3 @@
import logging
import multiprocessing
import os.path
import platform
@@ -15,16 +14,16 @@ from pytestqt.qtbot import QtBot
from buzz.__version__ import VERSION
from buzz.cache import TasksCache
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, AudioDevicesComboBox, DownloadModelProgressDialog,
FileTranscriberWidget, LanguagesComboBox, MainWindow,
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, AudioDevicesComboBox, FileTranscriberWidget,
LanguagesComboBox, MainWindow,
RecordingTranscriberWidget,
TemperatureValidator, TranscriptionTasksTableWidget, HuggingFaceSearchLineEdit,
TranscriptionOptionsGroupBox)
from buzz.settings.settings import Settings
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
from buzz.model_loader import ModelType
from buzz.settings.settings import Settings
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
TranscriptionOptions)
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
from tests.mock_sounddevice import MockInputStream, mock_query_devices
from .mock_qt import MockNetworkAccessManager, MockNetworkReply
@@ -84,33 +83,6 @@ class TestAudioDevicesComboBox:
assert audio_devices_combo_box.currentText() == 'Background Music'
class TestDownloadModelProgressDialog:
def test_should_show_dialog(self, qtbot: QtBot):
dialog = DownloadModelProgressDialog(parent=None)
qtbot.add_widget(dialog)
assert dialog.labelText() == 'Downloading model (0%, unknown time remaining)'
def test_should_update_label_on_progress(self, qtbot: QtBot):
dialog = DownloadModelProgressDialog(parent=None)
qtbot.add_widget(dialog)
dialog.set_fraction_completed(0.0)
dialog.set_fraction_completed(0.01)
logging.debug(dialog.labelText())
assert dialog.labelText().startswith(
'Downloading model (1%')
dialog.set_fraction_completed(0.1)
assert dialog.labelText().startswith(
'Downloading model (10%')
# Other windows should not be processing while models are being downloaded
def test_should_be_an_application_modal(self, qtbot: QtBot):
dialog = DownloadModelProgressDialog(parent=None)
qtbot.add_widget(dialog)
assert dialog.windowModality() == Qt.WindowModality.ApplicationModal
@pytest.fixture()
def tasks_cache(tmp_path, request: SubRequest):
cache = TasksCache(cache_dir=str(tmp_path))


@@ -1,14 +1,18 @@
from buzz.model_loader import ModelLoader, TranscriptionModel
from buzz.model_loader import TranscriptionModel, get_local_model_path, ModelDownloader
def get_model_path(transcription_model: TranscriptionModel) -> str:
model_loader = ModelLoader(model=transcription_model)
path = get_local_model_path(model=transcription_model)
if path is not None:
return path
model_loader = ModelDownloader(model=transcription_model)
model_path = ''
def on_load_model(path: str):
nonlocal model_path
model_path = path
model_loader.finished.connect(on_load_model)
model_loader.signals.finished.connect(on_load_model)
model_loader.run()
return model_path


@@ -1,4 +1,3 @@
import logging
import os
import pathlib
import platform
@@ -8,10 +7,10 @@ from typing import List
from unittest.mock import Mock, patch
import pytest
from PyQt6.QtCore import QThread, QCoreApplication
from PyQt6.QtCore import QThread
from pytestqt.qtbot import QtBot
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel, ModelLoader
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber,
@@ -28,20 +27,14 @@ class TestRecordingTranscriber:
transcription_model = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY)
model_loader = ModelLoader(model=transcription_model)
model_loader.moveToThread(thread)
transcriber = RecordingTranscriber(transcription_options=TranscriptionOptions(
model=transcription_model, language='fr', task=Task.TRANSCRIBE),
input_device_index=0, sample_rate=16_000)
transcriber.moveToThread(thread)
thread.started.connect(model_loader.run)
thread.finished.connect(thread.deleteLater)
model_loader.finished.connect(transcriber.start)
model_loader.finished.connect(model_loader.deleteLater)
mock_transcription = Mock()
transcriber.transcription.connect(mock_transcription)
@@ -66,7 +59,7 @@
[
(False, [Segment(0, 6560,
'Bienvenue dans Passe-Relle. Un podcast pensé pour')]),
(True, [Segment(0, 30, ''), Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
(True, [Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
])
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
file_transcription_options = FileTranscriptionOptions(
@@ -89,7 +82,7 @@
transcriber.run()
mock_progress.assert_called()
segments = mock_completed.call_args[0][0]
segments = [segment for segment in mock_completed.call_args[0][0] if len(segment.text) > 0]
for i, expected_segment in enumerate(expected_segments):
assert expected_segment.start == segments[i].start
assert expected_segment.end == segments[i].end


@@ -0,0 +1,30 @@
from PyQt6.QtCore import Qt
from buzz.model_loader import ModelType
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
class TestModelDownloadProgressDialog:
def test_should_show_dialog(self, qtbot):
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
qtbot.add_widget(dialog)
assert dialog.labelText() == 'Downloading model (0%)'
def test_should_update_label_on_progress(self, qtbot):
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
qtbot.add_widget(dialog)
dialog.set_value(0.0)
dialog.set_value(0.01)
assert dialog.labelText().startswith(
'Downloading model (1%')
dialog.set_value(0.1)
assert dialog.labelText().startswith(
'Downloading model (10%')
# Other windows should not be processing while models are being downloaded
def test_should_be_an_application_modal(self, qtbot):
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
qtbot.add_widget(dialog)
assert dialog.windowModality() == Qt.WindowModality.WindowModal


@@ -0,0 +1,14 @@
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
class TestModelTypeComboBox:
def test_should_display_items(self, qtbot):
widget = ModelTypeComboBox()
qtbot.add_widget(widget)
assert widget.count() == 5
assert widget.itemText(0) == 'Whisper'
assert widget.itemText(1) == 'Whisper.cpp'
assert widget.itemText(2) == 'Hugging Face'
assert widget.itemText(3) == 'Faster Whisper'
assert widget.itemText(4) == 'OpenAI Whisper API'


@@ -0,0 +1,77 @@
import os
import pytest
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QComboBox, QPushButton
from buzz.model_loader import get_whisper_file_path, WhisperModelSize, get_local_model_path, TranscriptionModel, \
ModelType
from buzz.widgets.models_preferences_widget import ModelsPreferencesWidget
class TestModelsPreferencesWidget:
@pytest.fixture(scope='class')
def clear_model_cache(self):
file_path = get_whisper_file_path(size=WhisperModelSize.TINY)
if os.path.isfile(file_path):
os.remove(file_path)
def test_should_show_model_list(self, qtbot):
widget = ModelsPreferencesWidget()
qtbot.add_widget(widget)
first_item = widget.model_list_widget.topLevelItem(0)
assert first_item.text(0) == 'Downloaded'
second_item = widget.model_list_widget.topLevelItem(1)
assert second_item.text(0) == 'Available for Download'
def test_should_change_model_type(self, qtbot):
widget = ModelsPreferencesWidget()
qtbot.add_widget(widget)
combo_box = widget.findChild(QComboBox)
assert isinstance(combo_box, QComboBox)
combo_box.setCurrentText('Faster Whisper')
first_item = widget.model_list_widget.topLevelItem(0)
assert first_item.text(0) == 'Downloaded'
second_item = widget.model_list_widget.topLevelItem(1)
assert second_item.text(0) == 'Available for Download'
def test_should_download_model(self, qtbot, clear_model_cache):
# make progress dialog non-modal to unblock qtbot.wait_until
widget = ModelsPreferencesWidget(progress_dialog_modality=Qt.WindowModality.NonModal)
qtbot.add_widget(widget)
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
assert get_local_model_path(model=model) is None
available_item = widget.model_list_widget.topLevelItem(1)
assert available_item.text(0) == 'Available for Download'
tiny_item = available_item.child(0)
assert tiny_item.text(0) == 'Tiny'
tiny_item.setSelected(True)
download_button = widget.findChild(QPushButton)
assert isinstance(download_button, QPushButton)
assert download_button.text() == 'Download'
download_button.click()
def downloaded_model():
assert not download_button.isEnabled()
_downloaded_item = widget.model_list_widget.topLevelItem(0)
assert _downloaded_item.childCount() > 0
assert _downloaded_item.child(0).text(0) == 'Tiny'
_available_item = widget.model_list_widget.topLevelItem(1)
assert _available_item.childCount() == 0 or _available_item.child(0).text(0) != 'Tiny'
assert os.path.isfile(get_whisper_file_path(size=model.whisper_model_size))
qtbot.wait_until(callback=downloaded_model, timeout=60_000)


@@ -13,6 +13,7 @@ class TestPreferencesDialog:
tab_widget = dialog.findChild(QTabWidget)
assert isinstance(tab_widget, QTabWidget)
assert tab_widget.count() == 2
assert tab_widget.count() == 3
assert tab_widget.tabText(0) == 'General'
assert tab_widget.tabText(1) == 'Shortcuts'
assert tab_widget.tabText(1) == 'Models'
assert tab_widget.tabText(2) == 'Shortcuts'