Adding custom model size for Whisper.cpp and Faster Whisper (#820)

This commit is contained in:
Raivis Dejus 2024-07-02 20:51:39 +03:00 committed by GitHub
commit 2eeb03a251
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 245 additions and 47 deletions

View file

@ -20,6 +20,8 @@ import faster_whisper
import whisper
import huggingface_hub
from buzz.locale import _
# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
# the DLL in some envs is fixed.
@ -46,6 +48,7 @@ class WhisperModelSize(str, enum.Enum):
LARGE = "large"
LARGEV2 = "large-v2"
LARGEV3 = "large-v3"
CUSTOM = "custom"
def to_faster_whisper_model_size(self) -> str:
if self == WhisperModelSize.LARGE:
@ -112,9 +115,15 @@ HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
@dataclass()
class TranscriptionModel:
model_type: ModelType = ModelType.WHISPER
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
hugging_face_model_id: Optional[str] = "openai/whisper-tiny"
def __init__(
self,
model_type: ModelType = ModelType.WHISPER,
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY,
hugging_face_model_id: Optional[str] = ""
):
self.model_type = model_type
self.whisper_model_size = whisper_model_size
self.hugging_face_model_id = hugging_face_model_id
def __str__(self):
match self.model_type:
@ -135,10 +144,16 @@ class TranscriptionModel:
return (
self.model_type == ModelType.WHISPER
or self.model_type == ModelType.WHISPER_CPP
or self.model_type == ModelType.FASTER_WHISPER
) and self.get_local_model_path() is not None
def open_file_location(self):
    """Open the folder that holds this model's local files.

    Returns without doing anything when ``get_local_model_path()``
    reports no local copy (``None``).
    """
    model_path = self.get_local_model_path()

    # Bail out before any path manipulation: os.path.dirname(None)
    # raises TypeError, so the guard has to come first.
    if model_path is None:
        return

    # For Hugging Face and Faster Whisper models the path is stepped up
    # one extra level, so the folder opened is two levels above the
    # path returned by get_local_model_path().
    if self.model_type in (ModelType.HUGGING_FACE, ModelType.FASTER_WHISPER):
        model_path = os.path.dirname(model_path)

    self.open_path(path=os.path.dirname(model_path))
@ -160,6 +175,17 @@ class TranscriptionModel:
def delete_local_file(self):
    """Delete this model's local files, if a local copy exists.

    Hugging Face and Faster Whisper models are removed as a directory
    tree rooted two levels above the path returned by
    ``get_local_model_path()``; every other model type is removed as a
    single file. Returns without doing anything when there is no local
    copy.
    """
    model_path = self.get_local_model_path()

    # Nothing to delete. Without this guard os.path.dirname(None) and
    # os.remove(None) would both raise TypeError.
    if model_path is None:
        return

    if self.model_type in (ModelType.HUGGING_FACE, ModelType.FASTER_WHISPER):
        model_dir = os.path.dirname(os.path.dirname(model_path))

        logging.debug("Deleting model directory: %s", model_dir)

        # Best-effort removal: errors while deleting the tree are ignored.
        shutil.rmtree(model_dir, ignore_errors=True)
        return

    logging.debug("Deleting model file: %s", model_path)
    os.remove(model_path)
def get_local_model_path(self) -> Optional[str]:
@ -178,7 +204,7 @@ class TranscriptionModel:
if self.model_type == ModelType.FASTER_WHISPER:
try:
return download_faster_whisper_model(
size=self.whisper_model_size.value, local_files_only=True
model=self, local_files_only=True
)
except (ValueError, FileNotFoundError):
return None
@ -208,6 +234,7 @@ WHISPER_CPP_MODELS_SHA256 = {
"large-v1": "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d",
"large-v2": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
"large-v3": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2",
"custom": None,
}
@ -217,6 +244,10 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
def get_whisper_file_path(size: WhisperModelSize) -> str:
    """Return the local file path used for an OpenAI Whisper model.

    The file lives in the ``whisper`` sub-folder of ``model_root_dir``.
    The custom size maps to a file literally named ``custom``; any other
    size uses the base name of its download URL from ``whisper._MODELS``.
    """
    whisper_dir = os.path.join(model_root_dir, "whisper")

    if size == WhisperModelSize.CUSTOM:
        file_name = "custom"
    else:
        file_name = os.path.basename(whisper._MODELS[size.value])

    return os.path.join(whisper_dir, file_name)
@ -286,13 +317,17 @@ def download_from_huggingface(
allow_patterns: List[str],
progress: pyqtSignal(tuple),
):
progress.emit((1, 100))
progress.emit((0, 100))
model_root = huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[1:], # all, but largest
cache_dir=model_root_dir
)
try:
model_root = huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[1:], # all, but largest
cache_dir=model_root_dir
)
except Exception as exc:
logging.exception(exc)
return ""
progress.emit((1, 100))
@ -302,11 +337,16 @@ def download_from_huggingface(
model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size)
model_download_monitor.start_monitoring()
huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[:1], # largest
cache_dir=model_root_dir
)
try:
huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[:1], # largest
cache_dir=model_root_dir
)
except Exception as exc:
logging.exception(exc)
model_download_monitor.stop_monitoring()
return ""
model_download_monitor.stop_monitoring()
@ -314,17 +354,23 @@ def download_from_huggingface(
def download_faster_whisper_model(
size: str, local_files_only=False, progress: pyqtSignal(tuple) = None
model: TranscriptionModel, local_files_only=False, progress: pyqtSignal(tuple) = None
):
if size not in faster_whisper.utils._MODELS:
size = model.whisper_model_size.to_faster_whisper_model_size()
custom_repo_id = model.hugging_face_model_id
if size != WhisperModelSize.CUSTOM and size not in faster_whisper.utils._MODELS:
raise ValueError(
"Invalid model size '%s', expected one of: %s"
% (size, ", ".join(faster_whisper.utils._MODELS))
)
logging.debug("Downloading Faster Whisper model: %s", size)
if size == WhisperModelSize.CUSTOM and custom_repo_id == "":
raise ValueError("Custom model id is not provided")
if size == WhisperModelSize.LARGEV3:
if size == WhisperModelSize.CUSTOM:
repo_id = custom_repo_id
elif size == WhisperModelSize.LARGEV3:
repo_id = "Systran/faster-whisper-large-v3"
else:
repo_id = "guillaumekln/faster-whisper-%s" % size
@ -358,20 +404,28 @@ class ModelDownloader(QRunnable):
progress = pyqtSignal(tuple) # (current, total)
error = pyqtSignal(str)
def __init__(self, model: TranscriptionModel):
def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
super().__init__()
self.signals = self.Signals()
self.model = model
self.stopped = False
self.custom_model_url = custom_model_url
def run(self) -> None:
logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.to_whisper_cpp_model_size()
url = huggingface_hub.hf_hub_url(
repo_id="ggerganov/whisper.cpp",
filename=f"ggml-{model_name}.bin",
)
if self.custom_model_url:
url = self.custom_model_url
else:
url = huggingface_hub.hf_hub_url(
repo_id="ggerganov/whisper.cpp",
filename=f"ggml-{model_name}.bin",
)
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
return self.download_model_to_path(
@ -388,9 +442,13 @@ class ModelDownloader(QRunnable):
if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
model=self.model,
progress=self.signals.progress,
)
if model_path == "":
self.signals.error.emit(_("Error"))
self.signals.finished.emit(model_path)
return
@ -417,7 +475,7 @@ class ModelDownloader(QRunnable):
if downloaded:
self.signals.finished.emit(file_path)
except requests.RequestException:
self.signals.error.emit("A connection error occurred")
self.signals.error.emit(_("A connection error occurred"))
logging.exception("")
except Exception as exc:
self.signals.error.emit(str(exc))

View file

@ -38,6 +38,8 @@ class Settings:
DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name"
CUSTOM_OPENAI_BASE_URL = "transcriber/custom-openai-base-url"
CUSTOM_FASTER_WHISPER_ID = "transcriber/custom-faster-whisper-id"
HUGGINGFACE_MODEL_ID = "transcriber/huggingface-model-id"
SHORTCUTS = "shortcuts"
@ -50,6 +52,36 @@ class Settings:
def set_value(self, key: Key, value: typing.Any) -> None:
self.settings.setValue(key.value, value)
def save_custom_model_id(self, model) -> None:
    """Store the model's Hugging Face id under the key for its model type.

    Faster Whisper models are saved under ``CUSTOM_FASTER_WHISPER_ID`` and
    Hugging Face models under ``HUGGINGFACE_MODEL_ID``. Any other model
    type is ignored and nothing is written.
    """
    # NOTE(review): imported locally, presumably to avoid a circular
    # import with buzz.model_loader - confirm.
    from buzz.model_loader import ModelType

    model_type = model.model_type
    if model_type == ModelType.FASTER_WHISPER:
        settings_key = Settings.Key.CUSTOM_FASTER_WHISPER_ID
    elif model_type == ModelType.HUGGING_FACE:
        settings_key = Settings.Key.HUGGINGFACE_MODEL_ID
    else:
        return

    self.set_value(settings_key, model.hugging_face_model_id)
def load_custom_model_id(self, model) -> str:
    """Return the stored Hugging Face id matching the model's type.

    Reads ``CUSTOM_FASTER_WHISPER_ID`` for Faster Whisper models and
    ``HUGGINGFACE_MODEL_ID`` for Hugging Face models, each with an empty
    string as the default. Any other model type yields an empty string.
    """
    # NOTE(review): imported locally, presumably to avoid a circular
    # import with buzz.model_loader - confirm.
    from buzz.model_loader import ModelType

    model_type = model.model_type
    if model_type == ModelType.FASTER_WHISPER:
        return self.value(Settings.Key.CUSTOM_FASTER_WHISPER_ID, "")
    if model_type == ModelType.HUGGING_FACE:
        return self.value(Settings.Key.HUGGINGFACE_MODEL_ID, "")

    return ""
def value(
self,
key: Key,

View file

@ -13,7 +13,7 @@ import tqdm
from PyQt6.QtCore import QObject
from buzz.conn import pipe_stderr
from buzz.model_loader import ModelType
from buzz.model_loader import ModelType, WhisperModelSize
from buzz.transformers_whisper import TransformersWhisper
from buzz.transcriber.file_transcriber import FileTranscriber
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
@ -131,8 +131,13 @@ class WhisperFileTranscriber(FileTranscriber):
@classmethod
def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
if task.transcription_options.model.whisper_model_size == WhisperModelSize.CUSTOM:
model_size_or_path = task.transcription_options.model.hugging_face_model_id
else:
model_size_or_path = task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
model = faster_whisper.WhisperModel(
model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
model_size_or_path=model_size_or_path
)
whisper_segments, info = model.transcribe(
audio=task.file_path,

View file

@ -1,3 +1,4 @@
import logging
from typing import Optional
from PyQt6.QtCore import Qt, QThreadPool
@ -18,8 +19,13 @@ from buzz.model_loader import (
TranscriptionModel,
ModelDownloader,
)
from buzz.settings.settings import Settings
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
from buzz.widgets.line_edit import LineEdit
from buzz.widgets.transcriber.hugging_face_search_line_edit import (
HuggingFaceSearchLineEdit,
)
class ModelsPreferencesWidget(QWidget):
@ -32,6 +38,7 @@ class ModelsPreferencesWidget(QWidget):
):
super().__init__(parent)
self.settings = Settings()
self.model_downloader: Optional[ModelDownloader] = None
model_types = [
@ -67,6 +74,20 @@ class ModelsPreferencesWidget(QWidget):
buttons_layout = QHBoxLayout()
self.custom_model_id_input = HuggingFaceSearchLineEdit()
self.custom_model_id_input.setObjectName("ModelIdInput")
self.custom_model_id_input.setPlaceholderText(_("Huggingface ID of a Faster whisper model"))
self.custom_model_id_input.textChanged.connect(self.on_custom_model_id_input_changed)
layout.addRow("", self.custom_model_id_input)
self.custom_model_id_input.hide()
self.custom_model_link_input = LineEdit()
self.custom_model_link_input.setObjectName("ModelLinkInput")
self.custom_model_link_input.textChanged.connect(self.on_custom_model_link_input_changed)
layout.addRow("", self.custom_model_link_input)
self.custom_model_link_input.hide()
self.download_button = QPushButton(_("Download"))
self.download_button.setObjectName("DownloadButton")
self.download_button.clicked.connect(self.on_download_button_clicked)
@ -100,17 +121,11 @@ class ModelsPreferencesWidget(QWidget):
self.model.whisper_model_size = item_data
self.reset()
@staticmethod
def can_delete_model(model: TranscriptionModel):
return (
model.model_type == ModelType.WHISPER
or model.model_type == ModelType.WHISPER_CPP
) and model.get_local_model_path() is not None
def reset(self):
# reset buttons
path = self.model.get_local_model_path()
self.download_button.setVisible(path is None)
self.download_button.setEnabled(self.model.whisper_model_size != WhisperModelSize.CUSTOM)
self.delete_button.setVisible(self.model.is_deletable())
self.show_file_location_button.setVisible(self.model.is_deletable())
@ -129,12 +144,45 @@ class ModelsPreferencesWidget(QWidget):
self.model_list_widget.setHeaderHidden(True)
self.model_list_widget.setAlternatingRowColors(True)
self.model.hugging_face_model_id = self.settings.load_custom_model_id(self.model)
self.custom_model_id_input.setText(self.model.hugging_face_model_id)
if (self.model.whisper_model_size == WhisperModelSize.CUSTOM
and self.model.model_type == ModelType.FASTER_WHISPER):
self.custom_model_id_input.show()
self.download_button.setEnabled(
self.model.hugging_face_model_id != ""
)
else:
self.custom_model_id_input.hide()
if self.model.model_type == ModelType.WHISPER_CPP:
self.custom_model_link_input.setPlaceholderText(
_("Download link to Whisper.cpp ggml model file")
)
if (self.model.whisper_model_size == WhisperModelSize.CUSTOM
and self.model.model_type == ModelType.WHISPER_CPP
and path is None):
self.custom_model_link_input.show()
self.download_button.setEnabled(
self.custom_model_link_input.text() != "")
else:
self.custom_model_link_input.hide()
if self.model is None:
return
for model_size in WhisperModelSize:
# Skip custom size for OpenAI Whisper
if (self.model.model_type == ModelType.WHISPER and
model_size == WhisperModelSize.CUSTOM):
continue
model = TranscriptionModel(
model_type=self.model.model_type, whisper_model_size=model_size
model_type=self.model.model_type,
whisper_model_size=WhisperModelSize(model_size),
hugging_face_model_id=self.model.hugging_face_model_id,
)
model_path = model.get_local_model_path()
parent = downloaded_item if model_path is not None else available_item
@ -149,6 +197,16 @@ class ModelsPreferencesWidget(QWidget):
self.model.model_type = model_type
self.reset()
def on_custom_model_id_input_changed(self, text):
    """Handle edits to the custom model id input.

    Copies the new text onto the current model, saves it through the
    settings object, and enables the download button only when the id
    is not an empty string.
    """
    current_model = self.model
    current_model.hugging_face_model_id = text
    self.settings.save_custom_model_id(current_model)

    has_model_id = current_model.hugging_face_model_id != ""
    self.download_button.setEnabled(has_model_id)
def on_custom_model_link_input_changed(self, text):
    """Enable the download button only while the link input is non-empty."""
    link_entered = text != ""
    self.download_button.setEnabled(link_entered)
def on_download_button_clicked(self):
self.progress_dialog = ModelDownloadProgressDialog(
model_type=self.model.model_type,
@ -159,7 +217,15 @@ class ModelsPreferencesWidget(QWidget):
self.download_button.setEnabled(False)
self.model_downloader = ModelDownloader(model=self.model)
if (self.model.whisper_model_size == WhisperModelSize.CUSTOM and
self.model.model_type == ModelType.WHISPER_CPP):
self.model_downloader = ModelDownloader(
model=self.model,
custom_model_url=self.custom_model_link_input.text()
)
else:
self.model_downloader = ModelDownloader(model=self.model)
self.model_downloader.signals.finished.connect(self.on_download_completed)
self.model_downloader.signals.progress.connect(self.on_download_progress)
self.model_downloader.signals.error.connect(self.on_download_error)
@ -185,10 +251,12 @@ class ModelsPreferencesWidget(QWidget):
def on_download_error(self, error: str):
self.progress_dialog.cancel()
self.progress_dialog.close()
self.progress_dialog = None
self.download_button.setEnabled(True)
self.reset()
QMessageBox.warning(self, _("Error"), f"{_('Download failed')}: {error}")
download_failed_label = _('Download failed')
QMessageBox.warning(self, _("Error"), f"{download_failed_label}: {error}")
def on_download_progress(self, progress: tuple):
self.progress_dialog.set_value(float(progress[0]) / progress[1])

View file

@ -14,12 +14,11 @@ from PyQt6.QtCore import (
QEvent,
)
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtCore import QSettings
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem
from buzz.locale import _
from buzz.widgets.line_edit import LineEdit
from buzz.settings.settings import APP_NAME
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
@ -29,11 +28,12 @@ class HuggingFaceSearchLineEdit(LineEdit):
def __init__(
self,
default_value: str,
default_value: str = "",
network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None,
):
super().__init__(default_value, parent)
self.setPlaceholderText(_("Huggingface ID of a model"))
self.setMinimumWidth(150)

View file

@ -5,6 +5,7 @@ from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox
from buzz.locale import _
from buzz.settings.settings import Settings
from buzz.model_loader import ModelType, WhisperModelSize
from buzz.transcriber.transcriber import TranscriptionOptions, Task
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
@ -29,6 +30,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
parent: Optional[QWidget] = None,
):
super().__init__(title="", parent=parent)
self.settings = Settings()
self.transcription_options = default_transcription_options
self.form_layout = QFormLayout(self)
@ -49,12 +51,8 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems(
[size.value.title() for size in WhisperModelSize]
[size.value.title() for size in WhisperModelSize if size != WhisperModelSize.CUSTOM]
)
if default_transcription_options.model.whisper_model_size is not None:
self.whisper_model_size_combo_box.setCurrentText(
default_transcription_options.model.whisper_model_size.value.title()
)
self.whisper_model_size_combo_box.currentTextChanged.connect(
self.on_whisper_model_size_changed
)
@ -72,6 +70,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.hugging_face_search_line_edit.model_selected.connect(
self.on_hugging_face_model_changed
)
self.hugging_face_search_line_edit.setVisible(False)
self.tasks_combo_box = TasksComboBox(
default_task=self.transcription_options.task, parent=self
@ -122,9 +121,40 @@ class TranscriptionOptionsGroupBox(QGroupBox):
def reset_visible_rows(self):
model_type = self.transcription_options.model.model_type
whisper_model_size = self.transcription_options.model.whisper_model_size
if (model_type == ModelType.HUGGING_FACE
or (whisper_model_size == WhisperModelSize.CUSTOM
and model_type == ModelType.FASTER_WHISPER)):
self.transcription_options.model.hugging_face_model_id = (
self.settings.load_custom_model_id(self.transcription_options.model))
self.hugging_face_search_line_edit.setText(
self.transcription_options.model.hugging_face_model_id)
self.form_layout.setRowVisible(
self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE
self.hugging_face_search_line_edit,
(model_type == ModelType.HUGGING_FACE)
or (model_type == ModelType.FASTER_WHISPER
and whisper_model_size == WhisperModelSize.CUSTOM),
)
custom_model_index = (self.whisper_model_size_combo_box
.findText(WhisperModelSize.CUSTOM.value.title()))
if (model_type == ModelType.WHISPER
and whisper_model_size == WhisperModelSize.CUSTOM
and custom_model_index != -1):
self.whisper_model_size_combo_box.removeItem(custom_model_index)
if ((model_type == ModelType.WHISPER_CPP or model_type == ModelType.FASTER_WHISPER)
and custom_model_index == -1):
self.whisper_model_size_combo_box.addItem(
WhisperModelSize.CUSTOM.value.title()
)
self.whisper_model_size_combo_box.setCurrentText(
self.transcription_options.model.whisper_model_size.value.title()
)
self.form_layout.setRowVisible(
self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER)
@ -146,8 +176,13 @@ class TranscriptionOptionsGroupBox(QGroupBox):
def on_whisper_model_size_changed(self, text: str):
    """Handle a new selection in the Whisper model size combo box.

    Converts the lower-cased combo box text to a ``WhisperModelSize``,
    stores it on the current transcription options, refreshes which
    form rows are visible, and emits ``transcription_options_changed``.
    """
    selected_size = WhisperModelSize(text.lower())
    options = self.transcription_options
    options.model.whisper_model_size = selected_size

    # The visible rows depend on the selected size, so refresh them
    # before notifying listeners.
    self.reset_visible_rows()
    self.transcription_options_changed.emit(self.transcription_options)
def on_hugging_face_model_changed(self, model: str):
self.transcription_options.model.hugging_face_model_id = model
self.transcription_options_changed.emit(self.transcription_options)
self.settings.save_custom_model_id(self.transcription_options.model)