mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-29 05:00:21 +02:00
Add models preferences (#421)
This commit is contained in:
parent
0a4be2b195
commit
cb5ad74620
146
buzz/gui.py
146
buzz/gui.py
|
@ -12,16 +12,15 @@ import sounddevice
|
|||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (QObject, Qt, QThread,
|
||||
QTimer, QUrl, pyqtSignal, QModelIndex, QPoint,
|
||||
QUrlQuery, QMetaObject, QEvent)
|
||||
QUrlQuery, QMetaObject, QEvent, QThreadPool)
|
||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||
QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
|
||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||
QDialogButtonBox, QFileDialog, QLabel, QLineEdit,
|
||||
QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox, QTableWidget,
|
||||
QMenuBar, QFormLayout, QTableWidgetItem,
|
||||
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
|
||||
QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
|
||||
from whisper import tokenizer
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
|
@ -30,7 +29,8 @@ from .action import Action
|
|||
from .assets import get_asset_path
|
||||
from .icon import Icon
|
||||
from .locale import _
|
||||
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
|
||||
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, get_local_model_path, \
|
||||
ModelDownloader
|
||||
from .paths import file_paths_as_title
|
||||
from .recording import RecordingAmplitudeListener
|
||||
from .settings.settings import Settings, APP_NAME
|
||||
|
@ -43,6 +43,8 @@ from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Ou
|
|||
FileTranscriberQueueWorker, FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL,
|
||||
DEFAULT_WHISPER_TEMPERATURE)
|
||||
from .widgets.line_edit import LineEdit
|
||||
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from .widgets.model_type_combo_box import ModelTypeComboBox
|
||||
from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
from .widgets.preferences_dialog import PreferencesDialog
|
||||
from .widgets.toolbar import ToolBar
|
||||
|
@ -163,33 +165,6 @@ class RecordButton(QPushButton):
|
|||
self.setDefault(False)
|
||||
|
||||
|
||||
class DownloadModelProgressDialog(QProgressDialog):
|
||||
start_time: datetime
|
||||
|
||||
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(_('Downloading model (0%, unknown time remaining)'),
|
||||
_('Cancel'), 0, 100, parent, *args)
|
||||
|
||||
# Setting this to a high value to avoid showing the dialog for models that
|
||||
# are checked locally but set progress to 0 immediately, i.e. Hugging Face or Faster Whisper models
|
||||
self.setMinimumDuration(10_000)
|
||||
|
||||
self.setWindowModality(Qt.WindowModality.ApplicationModal)
|
||||
self.start_time = datetime.now()
|
||||
self.setFixedSize(self.size())
|
||||
|
||||
def set_fraction_completed(self, fraction_completed: float) -> None:
|
||||
self.setValue(int(fraction_completed * self.maximum()))
|
||||
|
||||
if fraction_completed > 0.0:
|
||||
time_spent = (datetime.now() - self.start_time).total_seconds()
|
||||
time_left = (time_spent / fraction_completed) - time_spent
|
||||
|
||||
self.setLabelText(_('Downloading model') +
|
||||
f' ({fraction_completed :.0%}, ' +
|
||||
humanize.naturaldelta(time_left) + ')')
|
||||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = parent.tr(
|
||||
'An error occurred while loading the Whisper model') + \
|
||||
|
@ -200,9 +175,8 @@ def show_model_download_error_dialog(parent: QWidget, error: str):
|
|||
|
||||
|
||||
class FileTranscriberWidget(QWidget):
|
||||
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
model_loader: Optional[ModelLoader] = None
|
||||
transcriber_thread: Optional[QThread] = None
|
||||
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
|
||||
model_loader: Optional[ModelDownloader] = None
|
||||
file_transcription_options: FileTranscriptionOptions
|
||||
transcription_options: TranscriptionOptions
|
||||
is_transcribing = False
|
||||
|
@ -291,23 +265,16 @@ class FileTranscriberWidget(QWidget):
|
|||
def on_click_run(self):
|
||||
self.run_button.setDisabled(True)
|
||||
|
||||
self.transcriber_thread = QThread()
|
||||
self.model_loader = ModelLoader(model=self.transcription_options.model)
|
||||
self.model_loader.moveToThread(self.transcriber_thread)
|
||||
model_path = get_local_model_path(model=self.transcription_options.model)
|
||||
if model_path is not None:
|
||||
self.on_model_loaded(model_path)
|
||||
return
|
||||
|
||||
self.transcriber_thread.started.connect(self.model_loader.run)
|
||||
self.model_loader.finished.connect(
|
||||
self.transcriber_thread.quit)
|
||||
|
||||
self.model_loader.progress.connect(self.on_download_model_progress)
|
||||
|
||||
self.model_loader.error.connect(self.on_download_model_error)
|
||||
self.model_loader.error.connect(self.model_loader.deleteLater)
|
||||
|
||||
self.model_loader.finished.connect(self.on_model_loaded)
|
||||
self.model_loader.finished.connect(self.model_loader.deleteLater)
|
||||
|
||||
self.transcriber_thread.start()
|
||||
self.model_loader = ModelDownloader(model=self.transcription_options.model)
|
||||
self.model_loader.signals.progress.connect(self.on_download_model_progress)
|
||||
self.model_loader.signals.error.connect(self.on_download_model_error)
|
||||
self.model_loader.signals.finished.connect(self.on_model_loaded)
|
||||
QThreadPool().globalInstance().start(self.model_loader)
|
||||
|
||||
def on_model_loaded(self, model_path: str):
|
||||
self.reset_transcriber_controls()
|
||||
|
@ -320,12 +287,13 @@ class FileTranscriberWidget(QWidget):
|
|||
(current_size, total_size) = progress
|
||||
|
||||
if self.model_download_progress_dialog is None:
|
||||
self.model_download_progress_dialog = DownloadModelProgressDialog(parent=self)
|
||||
self.model_download_progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.transcription_options.model.model_type, parent=self)
|
||||
self.model_download_progress_dialog.canceled.connect(
|
||||
self.on_cancel_model_progress_dialog)
|
||||
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
|
||||
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
|
@ -337,7 +305,7 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.stop()
|
||||
self.model_loader.cancel()
|
||||
self.reset_model_download()
|
||||
|
||||
def reset_model_download(self):
|
||||
|
@ -349,8 +317,8 @@ class FileTranscriberWidget(QWidget):
|
|||
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
|
||||
|
||||
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
|
||||
if self.transcriber_thread is not None:
|
||||
self.transcriber_thread.wait()
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task)
|
||||
|
@ -432,9 +400,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
current_status: 'RecordingStatus'
|
||||
transcription_options: TranscriptionOptions
|
||||
selected_device_id: Optional[int]
|
||||
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
|
||||
transcriber: Optional[RecordingTranscriber] = None
|
||||
model_loader: Optional[ModelLoader] = None
|
||||
model_loader: Optional[ModelDownloader] = None
|
||||
transcription_thread: Optional[QThread] = None
|
||||
recording_amplitude_listener: Optional[RecordingAmplitudeListener] = None
|
||||
device_sample_rate: Optional[int] = None
|
||||
|
@ -541,29 +509,35 @@ class RecordingTranscriberWidget(QWidget):
|
|||
def start_recording(self):
|
||||
self.record_button.setDisabled(True)
|
||||
|
||||
model_path = get_local_model_path(model=self.transcription_options.model)
|
||||
if model_path is not None:
|
||||
self.on_model_loaded(model_path)
|
||||
return
|
||||
|
||||
self.model_loader = ModelDownloader(model=self.transcription_options.model)
|
||||
self.model_loader.signals.progress.connect(self.on_download_model_progress)
|
||||
self.model_loader.signals.error.connect(self.on_download_model_error)
|
||||
self.model_loader.signals.finished.connect(self.on_model_loaded)
|
||||
QThreadPool().globalInstance().start(self.model_loader)
|
||||
|
||||
def on_model_loaded(self, model_path: str):
|
||||
self.reset_recording_controls()
|
||||
self.model_loader = None
|
||||
|
||||
self.transcription_thread = QThread()
|
||||
|
||||
self.model_loader = ModelLoader(model=self.transcription_options.model)
|
||||
# TODO: make runnable
|
||||
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
|
||||
sample_rate=self.device_sample_rate,
|
||||
transcription_options=self.transcription_options)
|
||||
transcription_options=self.transcription_options,
|
||||
model_path=model_path)
|
||||
|
||||
self.model_loader.moveToThread(self.transcription_thread)
|
||||
self.transcriber.moveToThread(self.transcription_thread)
|
||||
|
||||
self.transcription_thread.started.connect(self.model_loader.run)
|
||||
self.transcription_thread.started.connect(self.transcriber.start)
|
||||
self.transcription_thread.finished.connect(
|
||||
self.transcription_thread.deleteLater)
|
||||
|
||||
self.model_loader.finished.connect(self.reset_recording_controls)
|
||||
self.model_loader.finished.connect(self.transcriber.start)
|
||||
self.model_loader.finished.connect(self.model_loader.deleteLater)
|
||||
|
||||
self.model_loader.progress.connect(
|
||||
self.on_download_model_progress)
|
||||
|
||||
self.model_loader.error.connect(self.on_download_model_error)
|
||||
|
||||
self.transcriber.transcription.connect(self.on_next_transcription)
|
||||
|
||||
self.transcriber.finished.connect(self.on_transcriber_finished)
|
||||
|
@ -580,12 +554,13 @@ class RecordingTranscriberWidget(QWidget):
|
|||
(current_size, total_size) = progress
|
||||
|
||||
if self.model_download_progress_dialog is None:
|
||||
self.model_download_progress_dialog = DownloadModelProgressDialog(parent=self)
|
||||
self.model_download_progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.transcription_options.model.model_type, parent=self)
|
||||
self.model_download_progress_dialog.canceled.connect(
|
||||
self.on_cancel_model_progress_dialog)
|
||||
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
|
||||
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
|
||||
|
||||
def set_recording_status_stopped(self):
|
||||
self.record_button.set_stopped()
|
||||
|
@ -640,9 +615,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
self.text_box.setPlaceholderText('')
|
||||
self.reset_record_button()
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
self.reset_model_download()
|
||||
|
||||
def reset_record_button(self):
|
||||
self.record_button.setEnabled(True)
|
||||
|
@ -651,6 +624,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.audio_meter_widget.update_amplitude(amplitude)
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
|
||||
self.stop_recording()
|
||||
if self.recording_amplitude_listener is not None:
|
||||
self.recording_amplitude_listener.stop_recording()
|
||||
|
@ -1264,17 +1240,10 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
|
||||
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
|
||||
|
||||
self.model_type_combo_box = QComboBox(self)
|
||||
if model_types is None:
|
||||
model_types = [model_type for model_type in ModelType]
|
||||
for model_type in model_types:
|
||||
# Hide Whisper.cpp option if whisper.dll did not load correctly.
|
||||
# See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
|
||||
if model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False:
|
||||
continue
|
||||
self.model_type_combo_box.addItem(model_type.value)
|
||||
self.model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
|
||||
self.model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)
|
||||
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
|
||||
default_model=default_transcription_options.model.model_type,
|
||||
parent=self)
|
||||
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
|
||||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
|
||||
|
@ -1339,8 +1308,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
model_type == ModelType.FASTER_WHISPER))
|
||||
self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
|
||||
def on_model_type_changed(self, text: str):
|
||||
model_type = ModelType(text)
|
||||
def on_model_type_changed(self, model_type: ModelType):
|
||||
self.transcription_options.model.model_type = model_type
|
||||
self.reset_visible_rows()
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
|
|
@ -2,17 +2,18 @@ import enum
|
|||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import faster_whisper
|
||||
import huggingface_hub
|
||||
import requests
|
||||
import whisper
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
from buzz import transformers_whisper
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class WhisperModelSize(enum.Enum):
|
||||
|
@ -51,121 +52,185 @@ def get_hugging_face_dataset_file_url(author: str, repository_name: str, filenam
|
|||
return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}'
|
||||
|
||||
|
||||
class ModelLoader(QObject):
|
||||
progress = pyqtSignal(tuple) # (current, total)
|
||||
finished = pyqtSignal(str)
|
||||
error = pyqtSignal(str)
|
||||
stopped = False
|
||||
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
    """Return the path at which the Whisper.cpp model of the given size is stored.

    The file is located in the 'Buzz' user cache directory and is named
    ``ggml-model-whisper-<size>.bin``. The path is only computed; this
    function does not check whether the file exists.
    """
    cache_dir = user_cache_dir('Buzz')
    model_file_name = f'ggml-model-whisper-{size.value}.bin'
    return os.path.join(cache_dir, model_file_name)
|
||||
|
||||
def __init__(self, model: TranscriptionModel, parent: Optional['QObject'] = None) -> None:
|
||||
super().__init__(parent)
|
||||
self.model_type = model.model_type
|
||||
self.whisper_model_size = model.whisper_model_size
|
||||
self.hugging_face_model_id = model.hugging_face_model_id
|
||||
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
if self.model_type == ModelType.WHISPER_CPP:
|
||||
root_dir = user_cache_dir('Buzz')
|
||||
model_name = self.whisper_model_size.value
|
||||
def get_whisper_file_path(size: WhisperModelSize) -> str:
    """Return the path at which the Whisper checkpoint of the given size is stored.

    The directory is the value of ``XDG_CACHE_HOME`` when that variable is
    set and ``~/.cache/whisper`` otherwise. The file name is the last path
    component of the model's download URL in ``whisper._MODELS``.
    """
    # NOTE(review): when XDG_CACHE_HOME is set the file is placed directly in
    # that directory, without a "whisper" sub-directory - confirm this is intended.
    fallback_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
    cache_dir = os.getenv("XDG_CACHE_HOME", fallback_dir)
    model_url = whisper._MODELS[size.value]
    return os.path.join(cache_dir, os.path.basename(model_url))
|
||||
|
||||
|
||||
def get_local_model_path(model: TranscriptionModel) -> Optional[str]:
    """Return the local path of an already-downloaded model, or None if it is missing.

    Nothing is downloaded by this function:

    - Whisper and Whisper.cpp models are looked up at their expected file path.
    - Faster Whisper and Hugging Face models are resolved with
      ``local_files_only=True``; a ``ValueError`` or ``FileNotFoundError`` from
      that lookup is reported as "not available" (None).
    - The OpenAI Whisper API has no local file, so an empty string is returned.

    Raises:
        Exception: if the model type is none of the types handled above.
    """
    kind = model.model_type

    if kind == ModelType.WHISPER_CPP:
        candidate = get_whisper_cpp_file_path(size=model.whisper_model_size)
        # isfile() is also False for paths that do not exist
        return candidate if os.path.isfile(candidate) else None

    if kind == ModelType.WHISPER:
        candidate = get_whisper_file_path(size=model.whisper_model_size)
        return candidate if os.path.isfile(candidate) else None

    if kind == ModelType.FASTER_WHISPER:
        try:
            return download_faster_whisper_model(size=model.whisper_model_size.value, local_files_only=True)
        except (ValueError, FileNotFoundError):
            return None

    if kind == ModelType.OPEN_AI_WHISPER_API:
        return ''

    if kind == ModelType.HUGGING_FACE:
        try:
            return huggingface_hub.snapshot_download(model.hugging_face_model_id, local_files_only=True)
        except (ValueError, FileNotFoundError):
            return None

    raise Exception("Unknown model type")
|
||||
|
||||
|
||||
def download_faster_whisper_model(size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None):
    """Fetch the Faster Whisper model of the given size from the Hugging Face Hub.

    Args:
        size: model size name; must be one of ``faster_whisper.utils._MODELS``.
        local_files_only: forwarded to ``huggingface_hub.snapshot_download``.
        tqdm_class: forwarded to ``huggingface_hub.snapshot_download``.

    Returns:
        Whatever ``huggingface_hub.snapshot_download`` returns for the
        ``guillaumekln/faster-whisper-<size>`` repository.

    Raises:
        ValueError: if ``size`` is not a known model size.
    """
    known_sizes = faster_whisper.utils._MODELS
    if size not in known_sizes:
        raise ValueError(
            "Invalid model size '%s', expected one of: %s" % (size, ", ".join(known_sizes))
        )

    # Only these files of the repository are requested
    required_files = ["config.json", "model.bin", "tokenizer.json", "vocabulary.txt"]

    return huggingface_hub.snapshot_download(
        "guillaumekln/faster-whisper-%s" % size,
        allow_patterns=required_files,
        local_files_only=local_files_only,
        tqdm_class=tqdm_class,
    )
|
||||
|
||||
|
||||
class ModelDownloader(QRunnable):
|
||||
class Signals(QObject):
|
||||
finished = pyqtSignal(str)
|
||||
progress = pyqtSignal(tuple) # (current, total)
|
||||
error = pyqtSignal(str)
|
||||
|
||||
def __init__(self, model: TranscriptionModel):
|
||||
super().__init__()
|
||||
|
||||
self.signals = self.Signals()
|
||||
self.model = model
|
||||
self.stopped = False
|
||||
|
||||
def run(self) -> None:
|
||||
if self.model.model_type == ModelType.WHISPER_CPP:
|
||||
model_name = self.model.whisper_model_size.value
|
||||
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
|
||||
filename=f'ggml-{model_name}.bin')
|
||||
file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin')
|
||||
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
|
||||
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
|
||||
self.download_model(url, file_path, expected_sha256)
|
||||
return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256)
|
||||
|
||||
elif self.model_type == ModelType.WHISPER:
|
||||
root_dir = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
model_name = self.whisper_model_size.value
|
||||
url = whisper._MODELS[model_name]
|
||||
file_path = os.path.join(root_dir, os.path.basename(url))
|
||||
if self.model.model_type == ModelType.WHISPER:
|
||||
url = whisper._MODELS[self.model.whisper_model_size.value]
|
||||
file_path = get_whisper_file_path(size=self.model.whisper_model_size)
|
||||
expected_sha256 = url.split('/')[-2]
|
||||
self.download_model(url, file_path, expected_sha256)
|
||||
return self.download_model_to_path(url=url, file_path=file_path, expected_sha256=expected_sha256)
|
||||
|
||||
elif self.model_type == ModelType.HUGGING_FACE:
|
||||
self.progress.emit((0, 100))
|
||||
progress = self.signals.progress
|
||||
|
||||
try:
|
||||
# Loads the model from cache or download if not in cache
|
||||
transformers_whisper.load_model(self.hugging_face_model_id)
|
||||
except (FileNotFoundError, EnvironmentError) as exception:
|
||||
self.error.emit(f'{exception}')
|
||||
return
|
||||
# gross abuse of power...
|
||||
class _tqdm(tqdm):
|
||||
def update(self, n: float | None = ...) -> bool | None:
|
||||
progress.emit((n, self.total))
|
||||
return super().update(n)
|
||||
|
||||
self.progress.emit((100, 100))
|
||||
file_path = self.hugging_face_model_id
|
||||
def close(self):
|
||||
progress.emit((self.n, self.total))
|
||||
return super().close()
|
||||
|
||||
elif self.model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
file_path = ""
|
||||
if self.model.model_type == ModelType.FASTER_WHISPER:
|
||||
model_path = download_faster_whisper_model(size=self.model.whisper_model_size.value, tqdm_class=_tqdm)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
elif self.model_type == ModelType.FASTER_WHISPER:
|
||||
self.progress.emit((0, 100))
|
||||
file_path = faster_whisper.download_model(size=self.whisper_model_size.value)
|
||||
self.progress.emit((100, 100))
|
||||
if self.model.model_type == ModelType.HUGGING_FACE:
|
||||
model_path = huggingface_hub.snapshot_download(self.model.hugging_face_model_id, tqdm_class=_tqdm)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
else:
|
||||
raise Exception("Invalid model type: " + self.model_type.value)
|
||||
if self.model.model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
self.signals.finished.emit('')
|
||||
return
|
||||
|
||||
self.finished.emit(file_path)
|
||||
raise Exception("Invalid model type: " + self.model.model_type.value)
|
||||
|
||||
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
|
||||
def download_model_to_path(self, url: str, file_path: str, expected_sha256: Optional[str]):
|
||||
try:
|
||||
logging.debug(f'Downloading model from {url} to {file_path}')
|
||||
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
if os.path.exists(file_path) and not os.path.isfile(file_path):
|
||||
raise RuntimeError(
|
||||
f"{file_path} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
if expected_sha256 is None:
|
||||
return file_path
|
||||
|
||||
model_bytes = open(file_path, "rb").read()
|
||||
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
|
||||
if model_sha256 == expected_sha256:
|
||||
return file_path
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
# Downloads the model using the requests module instead of urllib to
|
||||
# use the certs from certifi when the app is running in frozen mode
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output:
|
||||
source.raise_for_status()
|
||||
total_size = float(source.headers.get('Content-Length', 0))
|
||||
current = 0.0
|
||||
self.progress.emit((current, total_size))
|
||||
for chunk in source.iter_content(chunk_size=8192):
|
||||
if self.stopped:
|
||||
return
|
||||
output.write(chunk)
|
||||
current += len(chunk)
|
||||
self.progress.emit((current, total_size))
|
||||
|
||||
if expected_sha256 is not None:
|
||||
model_bytes = open(file_path, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
|
||||
"model.")
|
||||
|
||||
logging.debug('Downloaded model')
|
||||
|
||||
return file_path
|
||||
except RuntimeError as exc:
|
||||
self.error.emit(str(exc))
|
||||
logging.exception('')
|
||||
downloaded = self.download_model(url, file_path, expected_sha256)
|
||||
if downloaded:
|
||||
self.signals.finished.emit(file_path)
|
||||
except requests.RequestException:
|
||||
self.error.emit('A connection error occurred')
|
||||
logging.exception('')
|
||||
except Exception:
|
||||
self.error.emit('An unknown error occurred')
|
||||
self.signals.error.emit('A connection error occurred')
|
||||
logging.exception('')
|
||||
except Exception as exc:
|
||||
self.signals.error.emit(str(exc))
|
||||
logging.exception(exc)
|
||||
|
||||
def stop(self):
|
||||
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]) -> bool:
    """Download the model at ``url`` to ``file_path`` and verify its checksum.

    If ``file_path`` already holds a regular file that matches
    ``expected_sha256`` (or no checksum is expected), nothing is downloaded.
    Otherwise the model is streamed to a temporary file next to the
    destination and moved into place only after the checksum has been verified.

    Args:
        url: location to download the model from.
        file_path: destination path of the model file.
        expected_sha256: hex SHA256 digest the file must have, or None to skip
            verification.

    Returns:
        True if the model is available at ``file_path``, False if the download
        was interrupted because ``self.stopped`` was set.

    Raises:
        RuntimeError: if ``file_path`` exists but is not a regular file, or the
            downloaded data does not match ``expected_sha256``.
        requests.RequestException: if the HTTP request fails.
    """

    def sha256_of(path: str) -> str:
        # Hash in chunks so a multi-gigabyte model is never held in memory at
        # once, and close the file handle deterministically.
        digest = hashlib.sha256()
        with open(path, 'rb') as model_file:
            for block in iter(lambda: model_file.read(1024 * 1024), b''):
                digest.update(block)
        return digest.hexdigest()

    logging.debug(f'Downloading model from {url} to {file_path}')

    target_dir = os.path.dirname(os.path.abspath(file_path))
    os.makedirs(target_dir, exist_ok=True)

    if os.path.exists(file_path) and not os.path.isfile(file_path):
        raise RuntimeError(
            f"{file_path} exists and is not a regular file")

    if os.path.isfile(file_path):
        if expected_sha256 is None:
            return True

        if sha256_of(file_path) == expected_sha256:
            return True

        warnings.warn(
            f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")

    # mkstemp() creates the file atomically (mktemp() is race-prone and
    # deprecated). Creating it in the destination directory keeps the final
    # move on a single filesystem.
    tmp_fd, tmp_file = tempfile.mkstemp(dir=target_dir, suffix='.part')
    logging.debug('Downloading to temporary file = %s', tmp_file)

    try:
        # Downloads the model using the requests module instead of urllib to
        # use the certs from certifi when the app is running in frozen mode
        with os.fdopen(tmp_fd, 'wb') as output, requests.get(url, stream=True, timeout=15) as source:
            source.raise_for_status()
            total_size = float(source.headers.get('Content-Length', 0))
            current = 0.0
            self.signals.progress.emit((current, total_size))
            for chunk in source.iter_content(chunk_size=8192):
                if self.stopped:
                    return False
                output.write(chunk)
                current += len(chunk)
                self.signals.progress.emit((current, total_size))

        if expected_sha256 is not None and sha256_of(tmp_file) != expected_sha256:
            raise RuntimeError(
                "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
                "model.")

        logging.debug('Downloaded model')

        # os.replace() overwrites a stale destination on every platform;
        # os.rename() raises on Windows when the destination already exists.
        os.replace(tmp_file, file_path)
        logging.debug('Moved file from %s to %s', tmp_file, file_path)
        return True
    finally:
        # Remove the partial file after a cancelled or failed download. After a
        # successful move the temporary path no longer exists.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
|
||||
|
||||
def cancel(self):
|
||||
self.stopped = True
|
||||
|
|
|
@ -107,20 +107,23 @@ class RecordingTranscriber(QObject):
|
|||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions,
|
||||
input_device_index: Optional[int], sample_rate: int, parent: Optional[QObject] = None) -> None:
|
||||
input_device_index: Optional[int], sample_rate: int, model_path: str,
|
||||
parent: Optional[QObject] = None) -> None:
|
||||
super().__init__(parent)
|
||||
self.transcription_options = transcription_options
|
||||
self.current_stream = None
|
||||
self.input_device_index = input_device_index
|
||||
self.sample_rate = sample_rate
|
||||
self.model_path = model_path
|
||||
self.n_batch_samples = 5 * self.sample_rate # every 5 seconds
|
||||
# pause queueing if more than 3 batches behind
|
||||
self.max_queue_size = 3 * self.n_batch_samples
|
||||
self.queue = np.ndarray([], dtype=np.float32)
|
||||
self.mutex = threading.Lock()
|
||||
|
||||
@pyqtSlot(str)
|
||||
def start(self, model_path: str):
|
||||
def start(self):
|
||||
model_path = self.model_path
|
||||
|
||||
if self.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
model = whisper.load_model(model_path)
|
||||
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
|
|
46
buzz/widgets/model_download_progress_dialog.py
Normal file
46
buzz/widgets/model_download_progress_dialog.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import humanize
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtWidgets import QProgressDialog, QWidget, QPushButton
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType
|
||||
|
||||
|
||||
class ModelDownloadProgressDialog(QProgressDialog):
    """Progress dialog showing the percentage and estimated time left of a model download.

    The dialog can only be cancelled for Whisper and Whisper.cpp models. For
    every other model type the cancel button is disabled and ``cancel()`` does
    nothing.
    """

    def __init__(self, model_type: ModelType, parent: Optional[QWidget] = None, modality=Qt.WindowModality.WindowModal):
        """Create the dialog.

        Args:
            model_type: type of the model being downloaded; decides whether the
                dialog is cancelable.
            parent: optional parent widget.
            modality: window modality applied to the dialog.
        """
        super().__init__(parent)

        self.cancelable = model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
        # Reference point for the "time remaining" estimate
        self.start_time = datetime.now()
        self.setRange(0, 100)
        self.setMinimumDuration(0)
        self.setWindowModality(modality)
        self.update_label_text(0)

        if not self.cancelable:
            # Swap in a disabled cancel button. Its label goes through the
            # translation function like the other user-visible text here.
            cancel_button = QPushButton(_('Cancel'), self)
            cancel_button.setEnabled(False)
            self.setCancelButton(cancel_button)

    def update_label_text(self, fraction_completed: float):
        """Set the label to the completed percentage and, once known, the time remaining.

        Args:
            fraction_completed: completed share of the download, where 1.0 is 100%.
        """
        label_text = f"{_('Downloading model')} ({fraction_completed:.0%}"
        if fraction_completed > 0:
            # Linear extrapolation: total time = time spent / fraction completed
            time_spent = (datetime.now() - self.start_time).total_seconds()
            time_left = (time_spent / fraction_completed) - time_spent
            label_text += f', {humanize.naturaldelta(time_left)} remaining'
        label_text += ')'

        self.setLabelText(label_text)

    def set_value(self, fraction_completed: float):
        """Update the progress bar and label; ignored once the dialog was canceled.

        Args:
            fraction_completed: completed share of the download, where 1.0 is 100%.
        """
        if self.wasCanceled():
            return
        self.setValue(int(fraction_completed * self.maximum()))
        self.update_label_text(fraction_completed)

    def cancel(self) -> None:
        """Cancel the dialog, but only when it is cancelable."""
        if self.cancelable:
            super().cancel()
|
32
buzz/widgets/model_type_combo_box.py
Normal file
32
buzz/widgets/model_type_combo_box.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QComboBox, QWidget
|
||||
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.transcriber import LOADED_WHISPER_DLL
|
||||
|
||||
|
||||
class ModelTypeComboBox(QComboBox):
    """Combo box of model types that emits ``changed`` with the selected ``ModelType``."""
    changed = pyqtSignal(ModelType)

    def __init__(self, model_types: Optional[List[ModelType]] = None, default_model: Optional[ModelType] = None,
                 parent: Optional[QWidget] = None):
        """Create the combo box.

        Args:
            model_types: model types to offer; every ``ModelType`` when None.
            default_model: model type to select initially, if given.
            parent: optional parent widget.
        """
        super().__init__(parent)

        offered_types = list(ModelType) if model_types is None else model_types

        for candidate in offered_types:
            # Hide the Whisper.cpp option if whisper.dll did not load correctly.
            # See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
            whisper_cpp_unavailable = candidate == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False
            if not whisper_cpp_unavailable:
                self.addItem(candidate.value)

        # The handler is connected before the default is applied
        self.currentTextChanged.connect(self.on_text_changed)
        if default_model is not None:
            self.setCurrentText(default_model.value)

    def on_text_changed(self, text: str):
        """Convert the selected text to a ``ModelType`` and emit it via ``changed``."""
        selected_type = ModelType(text)
        self.changed.emit(selected_type)
|
123
buzz/widgets/models_preferences_widget.py
Normal file
123
buzz/widgets/models_preferences_widget.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import Qt, QThreadPool
|
||||
from PyQt6.QtWidgets import QWidget, QFormLayout, QTreeWidget, QTreeWidgetItem, QPushButton, QMessageBox
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType, WhisperModelSize, get_local_model_path, TranscriptionModel, ModelDownloader
|
||||
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
|
||||
|
||||
class ModelsPreferencesWidget(QWidget):
    """Preferences tab that groups Whisper model sizes into downloaded and available, and downloads missing ones."""

    def __init__(self, progress_dialog_modality: Optional[Qt.WindowModality] = None, parent: Optional[QWidget] = None):
        """Build the model-type selector, the model tree and the download button.

        :param progress_dialog_modality: modality forwarded to the download progress dialog.
        :param parent: optional parent widget.
        """
        super().__init__(parent)

        self.model_downloader: Optional[ModelDownloader] = None
        # Currently selected model; starts at the tiny Whisper model.
        self.model = TranscriptionModel(model_type=ModelType.WHISPER,
                                        whisper_model_size=WhisperModelSize.TINY)
        self.progress_dialog_modality = progress_dialog_modality

        self.progress_dialog: Optional[ModelDownloadProgressDialog] = None

        layout = QFormLayout()
        model_type_combo_box = ModelTypeComboBox(
            model_types=[ModelType.WHISPER, ModelType.WHISPER_CPP, ModelType.FASTER_WHISPER],
            default_model=self.model.model_type, parent=self)
        model_type_combo_box.changed.connect(self.on_model_type_changed)
        # Routed through gettext like every other label in this widget.
        layout.addRow(_('Group'), model_type_combo_box)

        self.model_list_widget = QTreeWidget()
        self.model_list_widget.setColumnCount(1)
        self.reset_model_size_list()
        self.model_list_widget.currentItemChanged.connect(self.on_model_size_changed)
        layout.addWidget(self.model_list_widget)

        self.download_button = QPushButton(_('Download'))
        self.download_button.clicked.connect(self.on_download_button_clicked)
        self.reset_download_button()
        layout.addWidget(self.download_button)

        self.setLayout(layout)

    def on_model_size_changed(self, current: QTreeWidgetItem, _: QTreeWidgetItem):
        """Track the model size stored on the newly current tree item."""
        if current is None:
            return
        item_data = current.data(0, Qt.ItemDataRole.UserRole)
        # The two group header items carry no model size.
        if item_data is None:
            return
        self.model.whisper_model_size = item_data
        self.reset_download_button()

    def reset_download_button(self):
        """Enable the download button only when the selected model has no local copy."""
        model_path = get_local_model_path(model=self.model)
        self.download_button.setEnabled(model_path is None)

    def on_model_type_changed(self, model_type: ModelType):
        """Switch model type and refresh the tree and button for it."""
        self.model.model_type = model_type
        self.reset_model_size_list()
        self.reset_download_button()

    def on_download_button_clicked(self):
        """Show the progress dialog and start downloading the selected model on the global thread pool."""
        self.progress_dialog = ModelDownloadProgressDialog(model_type=self.model.model_type,
                                                           modality=self.progress_dialog_modality, parent=self)
        self.progress_dialog.canceled.connect(self.on_progress_dialog_canceled)

        self.download_button.setEnabled(False)

        self.model_downloader = ModelDownloader(model=self.model)
        self.model_downloader.signals.finished.connect(self.on_download_completed)
        self.model_downloader.signals.progress.connect(self.on_download_progress)
        self.model_downloader.signals.error.connect(self.on_download_error)
        # globalInstance() is a static accessor; no QThreadPool needs to be constructed to reach it.
        QThreadPool.globalInstance().start(self.model_downloader)

    def on_download_completed(self, _: str):
        """Drop the progress dialog reference and refresh the tree and button."""
        self.progress_dialog = None

        self.reset_download_button()
        self.reset_model_size_list()

    def on_download_error(self, error: str):
        """Dismiss the progress dialog, refresh the UI and report the failure."""
        # The dialog reference may already have been cleared by an earlier completion.
        if self.progress_dialog is not None:
            self.progress_dialog.cancel()
            self.progress_dialog = None

        self.reset_download_button()
        self.reset_model_size_list()
        QMessageBox.warning(self, _('Error'), f'Download failed: {error}')

    def on_download_progress(self, progress: tuple):
        """Forward download progress to the dialog as a fraction.

        ``progress`` is indexed as ``(progress[0], progress[1])``; presumably
        (current, total) -- confirm against ``ModelDownloader``.
        """
        # A progress signal can arrive after the dialog reference was cleared.
        if self.progress_dialog is None:
            return
        current, total = progress[0], progress[1]
        # Avoid dividing by zero when no total is reported.
        if not total:
            return
        self.progress_dialog.set_value(float(current) / total)

    def reset_model_size_list(self):
        """Rebuild the tree with one entry per model size under 'Downloaded' or 'Available for Download'."""
        self.model_list_widget.clear()

        # Constructing an item with a parent already appends it to that parent,
        # so no separate addTopLevelItems()/addChild() call is needed.
        downloaded_item = QTreeWidgetItem(self.model_list_widget)
        downloaded_item.setText(0, _('Downloaded'))
        downloaded_item.setFlags(
            downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)

        available_item = QTreeWidgetItem(self.model_list_widget)
        available_item.setText(0, _('Available for Download'))
        available_item.setFlags(
            available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)

        self.model_list_widget.expandToDepth(2)
        self.model_list_widget.setHeaderHidden(True)
        self.model_list_widget.setAlternatingRowColors(True)

        for model_size in WhisperModelSize:
            model_path = get_local_model_path(
                model=TranscriptionModel(model_type=self.model.model_type, whisper_model_size=model_size))
            parent = downloaded_item if model_path is not None else available_item
            item = QTreeWidgetItem(parent)
            item.setText(0, model_size.value.title())
            # The enum member is stored on the item so selection handlers can read it back.
            item.setData(0, Qt.ItemDataRole.UserRole, model_size)
            if self.model.whisper_model_size == model_size:
                item.setSelected(True)

    def on_progress_dialog_canceled(self):
        """Cancel the running download, if any, and refresh the tree and button."""
        if self.model_downloader is not None:
            self.model_downloader.cancel()
        self.reset_model_size_list()
        self.reset_download_button()
|
|
@ -4,8 +4,8 @@ from PyQt6.QtCore import pyqtSignal
|
|||
from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QTabWidget, QDialogButtonBox
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.store.keyring_store import KeyringStore
|
||||
from buzz.widgets.general_preferences_widget import GeneralPreferencesWidget
|
||||
from buzz.widgets.models_preferences_widget import ModelsPreferencesWidget
|
||||
from buzz.widgets.shortcuts_editor_preferences_widget import ShortcutsEditorPreferencesWidget
|
||||
|
||||
|
||||
|
@ -25,6 +25,9 @@ class PreferencesDialog(QDialog):
|
|||
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
|
||||
tab_widget.addTab(general_tab_widget, _('General'))
|
||||
|
||||
models_tab_widget = ModelsPreferencesWidget(parent=self)
|
||||
tab_widget.addTab(models_tab_widget, _('Models'))
|
||||
|
||||
shortcuts_table_widget = ShortcutsEditorPreferencesWidget(shortcuts, self)
|
||||
shortcuts_table_widget.shortcuts_changed.connect(self.shortcuts_changed)
|
||||
tab_widget.addTab(shortcuts_table_widget, _('Shortcuts'))
|
||||
|
|
13
main.py
13
main.py
|
@ -38,10 +38,15 @@ if __name__ == "__main__":
|
|||
|
||||
log_dir = user_log_dir(appname='Buzz')
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
logging.basicConfig(
|
||||
filename=os.path.join(log_dir, 'logs.txt'),
|
||||
level=logging.DEBUG,
|
||||
format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s")
|
||||
|
||||
log_format = "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
|
||||
logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.DEBUG, format=log_format)
|
||||
|
||||
if getattr(sys, 'frozen', False) is False:
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(stdout_handler)
|
||||
|
||||
from buzz.gui import Application
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
import multiprocessing
|
||||
import os.path
|
||||
import platform
|
||||
|
@ -15,16 +14,16 @@ from pytestqt.qtbot import QtBot
|
|||
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, AudioDevicesComboBox, DownloadModelProgressDialog,
|
||||
FileTranscriberWidget, LanguagesComboBox, MainWindow,
|
||||
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, AudioDevicesComboBox, FileTranscriberWidget,
|
||||
LanguagesComboBox, MainWindow,
|
||||
RecordingTranscriberWidget,
|
||||
TemperatureValidator, TranscriptionTasksTableWidget, HuggingFaceSearchLineEdit,
|
||||
TranscriptionOptionsGroupBox)
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
|
||||
TranscriptionOptions)
|
||||
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
from tests.mock_sounddevice import MockInputStream, mock_query_devices
|
||||
from .mock_qt import MockNetworkAccessManager, MockNetworkReply
|
||||
|
||||
|
@ -84,33 +83,6 @@ class TestAudioDevicesComboBox:
|
|||
assert audio_devices_combo_box.currentText() == 'Background Music'
|
||||
|
||||
|
||||
class TestDownloadModelProgressDialog:
    """Checks the label text and modality of ``DownloadModelProgressDialog``."""

    def test_should_show_dialog(self, qtbot: QtBot):
        progress_dialog = DownloadModelProgressDialog(parent=None)
        qtbot.add_widget(progress_dialog)
        expected_label = 'Downloading model (0%, unknown time remaining)'
        assert progress_dialog.labelText() == expected_label

    def test_should_update_label_on_progress(self, qtbot: QtBot):
        progress_dialog = DownloadModelProgressDialog(parent=None)
        qtbot.add_widget(progress_dialog)
        progress_dialog.set_fraction_completed(0.0)

        # The label prefix should follow the completed fraction.
        progress_dialog.set_fraction_completed(0.01)
        logging.debug(progress_dialog.labelText())
        label_at_one_percent = progress_dialog.labelText()
        assert label_at_one_percent.startswith('Downloading model (1%')

        progress_dialog.set_fraction_completed(0.1)
        label_at_ten_percent = progress_dialog.labelText()
        assert label_at_ten_percent.startswith('Downloading model (10%')

    # Other windows should not be processing while models are being downloaded
    def test_should_be_an_application_modal(self, qtbot: QtBot):
        progress_dialog = DownloadModelProgressDialog(parent=None)
        qtbot.add_widget(progress_dialog)
        modality = progress_dialog.windowModality()
        assert modality == Qt.WindowModality.ApplicationModal
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tasks_cache(tmp_path, request: SubRequest):
|
||||
cache = TasksCache(cache_dir=str(tmp_path))
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
from buzz.model_loader import ModelLoader, TranscriptionModel
|
||||
from buzz.model_loader import TranscriptionModel, get_local_model_path, ModelDownloader
|
||||
|
||||
|
||||
def get_model_path(transcription_model: TranscriptionModel) -> str:
    """Return a path to the model, downloading it first when no local copy exists.

    :param transcription_model: model whose path is wanted.
    :return: the local path when one already exists; otherwise the path
        delivered by the downloader's ``finished`` signal, or ``''`` if that
        signal never fires.
    """
    path = get_local_model_path(model=transcription_model)
    if path is not None:
        return path

    model_loader = ModelDownloader(model=transcription_model)
    model_path = ''

    def on_load_model(downloaded_path: str):
        nonlocal model_path
        model_path = downloaded_path

    model_loader.signals.finished.connect(on_load_model)
    # run() is invoked directly rather than through a thread pool, so the
    # download is performed before this function returns.
    model_loader.run()
    return model_path
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
|
@ -8,10 +7,10 @@ from typing import List
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from PyQt6.QtCore import QThread, QCoreApplication
|
||||
from PyQt6.QtCore import QThread
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel, ModelLoader
|
||||
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
|
||||
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
|
@ -28,20 +27,14 @@ class TestRecordingTranscriber:
|
|||
|
||||
transcription_model = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY)
|
||||
model_loader = ModelLoader(model=transcription_model)
|
||||
model_loader.moveToThread(thread)
|
||||
|
||||
transcriber = RecordingTranscriber(transcription_options=TranscriptionOptions(
|
||||
model=transcription_model, language='fr', task=Task.TRANSCRIBE),
|
||||
input_device_index=0, sample_rate=16_000)
|
||||
transcriber.moveToThread(thread)
|
||||
|
||||
thread.started.connect(model_loader.run)
|
||||
thread.finished.connect(thread.deleteLater)
|
||||
|
||||
model_loader.finished.connect(transcriber.start)
|
||||
model_loader.finished.connect(model_loader.deleteLater)
|
||||
|
||||
mock_transcription = Mock()
|
||||
transcriber.transcription.connect(mock_transcription)
|
||||
|
||||
|
@ -66,7 +59,7 @@ class TestWhisperCppFileTranscriber:
|
|||
[
|
||||
(False, [Segment(0, 6560,
|
||||
'Bienvenue dans Passe-Relle. Un podcast pensé pour')]),
|
||||
(True, [Segment(0, 30, ''), Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
|
||||
(True, [Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
|
||||
])
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
|
@ -89,7 +82,7 @@ class TestWhisperCppFileTranscriber:
|
|||
transcriber.run()
|
||||
|
||||
mock_progress.assert_called()
|
||||
segments = mock_completed.call_args[0][0]
|
||||
segments = [segment for segment in mock_completed.call_args[0][0] if len(segment.text) > 0]
|
||||
for i, expected_segment in enumerate(expected_segments):
|
||||
assert expected_segment.start == segments[i].start
|
||||
assert expected_segment.end == segments[i].end
|
||||
|
|
30
tests/widgets/model_download_progress_dialog.py
Normal file
30
tests/widgets/model_download_progress_dialog.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from PyQt6.QtCore import Qt
|
||||
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
|
||||
|
||||
class TestModelDownloadProgressDialog:
    """Checks the label text and default modality of ``ModelDownloadProgressDialog``."""

    def test_should_show_dialog(self, qtbot):
        progress_dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
        qtbot.add_widget(progress_dialog)
        expected_label = 'Downloading model (0%)'
        assert progress_dialog.labelText() == expected_label

    def test_should_update_label_on_progress(self, qtbot):
        progress_dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
        qtbot.add_widget(progress_dialog)
        progress_dialog.set_value(0.0)

        # The label prefix should follow the completed fraction.
        progress_dialog.set_value(0.01)
        label_at_one_percent = progress_dialog.labelText()
        assert label_at_one_percent.startswith('Downloading model (1%')

        progress_dialog.set_value(0.1)
        label_at_ten_percent = progress_dialog.labelText()
        assert label_at_ten_percent.startswith('Downloading model (10%')

    # With no modality passed in, the dialog is expected to be window-modal
    def test_should_be_an_application_modal(self, qtbot):
        progress_dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
        qtbot.add_widget(progress_dialog)
        modality = progress_dialog.windowModality()
        assert modality == Qt.WindowModality.WindowModal
|
14
tests/widgets/model_type_combo_box_test.py
Normal file
14
tests/widgets/model_type_combo_box_test.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
|
||||
|
||||
class TestModelTypeComboBox:
    """Checks the entries shown by a default-constructed ``ModelTypeComboBox``."""

    def test_should_display_items(self, qtbot):
        combo_box = ModelTypeComboBox()
        qtbot.add_widget(combo_box)

        # Comparing the full list checks both the item count and the order in one assertion.
        displayed_items = [combo_box.itemText(index) for index in range(combo_box.count())]
        assert displayed_items == [
            'Whisper',
            'Whisper.cpp',
            'Hugging Face',
            'Faster Whisper',
            'OpenAI Whisper API',
        ]
|
77
tests/widgets/models_preferences_widget_test.py
Normal file
77
tests/widgets/models_preferences_widget_test.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtWidgets import QComboBox, QPushButton
|
||||
|
||||
from buzz.model_loader import get_whisper_file_path, WhisperModelSize, get_local_model_path, TranscriptionModel, \
|
||||
ModelType
|
||||
from buzz.widgets.models_preferences_widget import ModelsPreferencesWidget
|
||||
|
||||
|
||||
class TestModelsPreferencesWidget:
    """Exercises the model tree, model-type switching and downloading in ``ModelsPreferencesWidget``."""

    @pytest.fixture(scope='class')
    def clear_model_cache(self):
        """Delete the tiny Whisper model file if it exists."""
        cached_file = get_whisper_file_path(size=WhisperModelSize.TINY)
        if not os.path.isfile(cached_file):
            return
        os.remove(cached_file)

    def test_should_show_model_list(self, qtbot):
        preferences = ModelsPreferencesWidget()
        qtbot.add_widget(preferences)

        tree = preferences.model_list_widget
        assert tree.topLevelItem(0).text(0) == 'Downloaded'
        assert tree.topLevelItem(1).text(0) == 'Available for Download'

    def test_should_change_model_type(self, qtbot):
        preferences = ModelsPreferencesWidget()
        qtbot.add_widget(preferences)

        model_type_selector = preferences.findChild(QComboBox)
        assert isinstance(model_type_selector, QComboBox)
        model_type_selector.setCurrentText('Faster Whisper')

        # Both group headers must still be present after switching the type.
        tree = preferences.model_list_widget
        assert tree.topLevelItem(0).text(0) == 'Downloaded'
        assert tree.topLevelItem(1).text(0) == 'Available for Download'

    def test_should_download_model(self, qtbot, clear_model_cache):
        # make progress dialog non-modal to unblock qtbot.wait_until
        preferences = ModelsPreferencesWidget(progress_dialog_modality=Qt.WindowModality.NonModal)
        qtbot.add_widget(preferences)

        tiny_model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)

        assert get_local_model_path(model=tiny_model) is None

        tree = preferences.model_list_widget
        available_group = tree.topLevelItem(1)
        assert available_group.text(0) == 'Available for Download'

        first_available = available_group.child(0)
        assert first_available.text(0) == 'Tiny'
        first_available.setSelected(True)

        button = preferences.findChild(QPushButton)
        assert isinstance(button, QPushButton)

        assert button.text() == 'Download'
        button.click()

        def model_is_downloaded():
            # The button stays disabled once the model has a local copy.
            assert not button.isEnabled()

            downloaded_group = tree.topLevelItem(0)
            assert downloaded_group.childCount() > 0
            assert downloaded_group.child(0).text(0) == 'Tiny'

            remaining_group = tree.topLevelItem(1)
            assert remaining_group.childCount() == 0 or remaining_group.child(0).text(0) != 'Tiny'

            assert os.path.isfile(get_whisper_file_path(size=tiny_model.whisper_model_size))

        qtbot.wait_until(callback=model_is_downloaded, timeout=60_000)
|
|
@ -13,6 +13,7 @@ class TestPreferencesDialog:
|
|||
|
||||
tab_widget = dialog.findChild(QTabWidget)
|
||||
assert isinstance(tab_widget, QTabWidget)
|
||||
assert tab_widget.count() == 2
|
||||
assert tab_widget.count() == 3
|
||||
assert tab_widget.tabText(0) == 'General'
|
||||
assert tab_widget.tabText(1) == 'Shortcuts'
|
||||
assert tab_widget.tabText(1) == 'Models'
|
||||
assert tab_widget.tabText(2) == 'Shortcuts'
|
||||
|
|
Loading…
Reference in a new issue