mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 22:55:46 +01:00
Adding custom model size for Whisper.cpp and Faster Whisper (#820)
This commit is contained in:
parent
4d06273305
commit
2eeb03a251
6 changed files with 245 additions and 47 deletions
|
|
@ -20,6 +20,8 @@ import faster_whisper
|
|||
import whisper
|
||||
import huggingface_hub
|
||||
|
||||
from buzz.locale import _
|
||||
|
||||
# Catch exception from whisper.dll not getting loaded.
|
||||
# TODO: Remove flag and try-except when issue with loading
|
||||
# the DLL in some envs is fixed.
|
||||
|
|
@ -46,6 +48,7 @@ class WhisperModelSize(str, enum.Enum):
|
|||
LARGE = "large"
|
||||
LARGEV2 = "large-v2"
|
||||
LARGEV3 = "large-v3"
|
||||
CUSTOM = "custom"
|
||||
|
||||
def to_faster_whisper_model_size(self) -> str:
|
||||
if self == WhisperModelSize.LARGE:
|
||||
|
|
@ -112,9 +115,15 @@ HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
|
|||
|
||||
@dataclass()
|
||||
class TranscriptionModel:
|
||||
model_type: ModelType = ModelType.WHISPER
|
||||
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
|
||||
hugging_face_model_id: Optional[str] = "openai/whisper-tiny"
|
||||
def __init__(
    self,
    model_type: ModelType = ModelType.WHISPER,
    whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY,
    hugging_face_model_id: Optional[str] = "",
):
    """Describe a transcription model.

    Args:
        model_type: Backend used for transcription; defaults to
            ``ModelType.WHISPER``.
        whisper_model_size: Whisper model size; defaults to
            ``WhisperModelSize.TINY``.
        hugging_face_model_id: Hugging Face repository id; defaults to
            an empty string.
    """
    (
        self.model_type,
        self.whisper_model_size,
        self.hugging_face_model_id,
    ) = (model_type, whisper_model_size, hugging_face_model_id)
|
||||
|
||||
def __str__(self):
|
||||
match self.model_type:
|
||||
|
|
@ -135,10 +144,16 @@ class TranscriptionModel:
|
|||
return (
|
||||
self.model_type == ModelType.WHISPER
|
||||
or self.model_type == ModelType.WHISPER_CPP
|
||||
or self.model_type == ModelType.FASTER_WHISPER
|
||||
) and self.get_local_model_path() is not None
|
||||
|
||||
def open_file_location(self):
    """Open the folder that contains this model's local files.

    Does nothing when ``get_local_model_path()`` returns ``None``.
    """
    model_path = self.get_local_model_path()

    # Check for a missing model *before* any path manipulation:
    # os.path.dirname(None) raises TypeError.
    if model_path is None:
        return

    if self.model_type in (ModelType.HUGGING_FACE, ModelType.FASTER_WHISPER):
        # For these model types the location to show is one level further
        # up than for single-file models.
        model_path = os.path.dirname(model_path)

    self.open_path(path=os.path.dirname(model_path))
|
||||
|
|
@ -160,6 +175,17 @@ class TranscriptionModel:
|
|||
|
||||
def delete_local_file(self):
    """Delete the locally stored copy of this model, if there is one.

    Hugging Face and Faster Whisper models are removed as a directory tree
    located two levels above the path returned by
    ``get_local_model_path()``; every other model type is removed as a
    single file. Does nothing when no local path is available.
    """
    model_path = self.get_local_model_path()

    # Nothing to delete; os.path.dirname(None) / os.remove(None) would
    # raise TypeError.
    if model_path is None:
        return

    if self.model_type in (ModelType.HUGGING_FACE, ModelType.FASTER_WHISPER):
        model_dir = os.path.dirname(os.path.dirname(model_path))

        logging.debug("Deleting model directory: %s", model_dir)

        # Best-effort removal: errors while deleting the tree are ignored.
        shutil.rmtree(model_dir, ignore_errors=True)
        return

    logging.debug("Deleting model file: %s", model_path)
    os.remove(model_path)
|
||||
|
||||
def get_local_model_path(self) -> Optional[str]:
|
||||
|
|
@ -178,7 +204,7 @@ class TranscriptionModel:
|
|||
if self.model_type == ModelType.FASTER_WHISPER:
|
||||
try:
|
||||
return download_faster_whisper_model(
|
||||
size=self.whisper_model_size.value, local_files_only=True
|
||||
model=self, local_files_only=True
|
||||
)
|
||||
except (ValueError, FileNotFoundError):
|
||||
return None
|
||||
|
|
@ -208,6 +234,7 @@ WHISPER_CPP_MODELS_SHA256 = {
|
|||
"large-v1": "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d",
|
||||
"large-v2": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
|
||||
"large-v3": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2",
|
||||
"custom": None,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -217,6 +244,10 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
|
|||
|
||||
def get_whisper_file_path(size: WhisperModelSize) -> str:
    """Return the local path used for the Whisper model of the given size.

    The custom size maps to a fixed ``custom`` entry; every other size uses
    the basename of its URL in ``whisper._MODELS``.
    """
    whisper_dir = os.path.join(model_root_dir, "whisper")
    file_name = (
        "custom"
        if size == WhisperModelSize.CUSTOM
        else os.path.basename(whisper._MODELS[size.value])
    )
    return os.path.join(whisper_dir, file_name)
|
||||
|
||||
|
|
@ -286,13 +317,17 @@ def download_from_huggingface(
|
|||
allow_patterns: List[str],
|
||||
progress: pyqtSignal(tuple),
|
||||
):
|
||||
progress.emit((1, 100))
|
||||
progress.emit((0, 100))
|
||||
|
||||
model_root = huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns[1:], # all, but largest
|
||||
cache_dir=model_root_dir
|
||||
)
|
||||
try:
|
||||
model_root = huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns[1:], # all, but largest
|
||||
cache_dir=model_root_dir
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
return ""
|
||||
|
||||
progress.emit((1, 100))
|
||||
|
||||
|
|
@ -302,11 +337,16 @@ def download_from_huggingface(
|
|||
model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size)
|
||||
model_download_monitor.start_monitoring()
|
||||
|
||||
huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns[:1], # largest
|
||||
cache_dir=model_root_dir
|
||||
)
|
||||
try:
|
||||
huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns[:1], # largest
|
||||
cache_dir=model_root_dir
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
model_download_monitor.stop_monitoring()
|
||||
return ""
|
||||
|
||||
model_download_monitor.stop_monitoring()
|
||||
|
||||
|
|
@ -314,17 +354,23 @@ def download_from_huggingface(
|
|||
|
||||
|
||||
def download_faster_whisper_model(
|
||||
size: str, local_files_only=False, progress: pyqtSignal(tuple) = None
|
||||
model: TranscriptionModel, local_files_only=False, progress: pyqtSignal(tuple) = None
|
||||
):
|
||||
if size not in faster_whisper.utils._MODELS:
|
||||
size = model.whisper_model_size.to_faster_whisper_model_size()
|
||||
custom_repo_id = model.hugging_face_model_id
|
||||
|
||||
if size != WhisperModelSize.CUSTOM and size not in faster_whisper.utils._MODELS:
|
||||
raise ValueError(
|
||||
"Invalid model size '%s', expected one of: %s"
|
||||
% (size, ", ".join(faster_whisper.utils._MODELS))
|
||||
)
|
||||
|
||||
logging.debug("Downloading Faster Whisper model: %s", size)
|
||||
if size == WhisperModelSize.CUSTOM and custom_repo_id == "":
|
||||
raise ValueError("Custom model id is not provided")
|
||||
|
||||
if size == WhisperModelSize.LARGEV3:
|
||||
if size == WhisperModelSize.CUSTOM:
|
||||
repo_id = custom_repo_id
|
||||
elif size == WhisperModelSize.LARGEV3:
|
||||
repo_id = "Systran/faster-whisper-large-v3"
|
||||
else:
|
||||
repo_id = "guillaumekln/faster-whisper-%s" % size
|
||||
|
|
@ -358,20 +404,28 @@ class ModelDownloader(QRunnable):
|
|||
progress = pyqtSignal(tuple) # (current, total)
|
||||
error = pyqtSignal(str)
|
||||
|
||||
def __init__(self, model: TranscriptionModel):
|
||||
def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
||||
self.signals = self.Signals()
|
||||
self.model = model
|
||||
self.stopped = False
|
||||
self.custom_model_url = custom_model_url
|
||||
|
||||
def run(self) -> None:
|
||||
logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)
|
||||
|
||||
if self.model.model_type == ModelType.WHISPER_CPP:
|
||||
model_name = self.model.whisper_model_size.to_whisper_cpp_model_size()
|
||||
url = huggingface_hub.hf_hub_url(
|
||||
repo_id="ggerganov/whisper.cpp",
|
||||
filename=f"ggml-{model_name}.bin",
|
||||
)
|
||||
|
||||
if self.custom_model_url:
|
||||
url = self.custom_model_url
|
||||
else:
|
||||
url = huggingface_hub.hf_hub_url(
|
||||
repo_id="ggerganov/whisper.cpp",
|
||||
filename=f"ggml-{model_name}.bin",
|
||||
)
|
||||
|
||||
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
|
||||
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
|
||||
return self.download_model_to_path(
|
||||
|
|
@ -388,9 +442,13 @@ class ModelDownloader(QRunnable):
|
|||
|
||||
if self.model.model_type == ModelType.FASTER_WHISPER:
|
||||
model_path = download_faster_whisper_model(
|
||||
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
|
||||
model=self.model,
|
||||
progress=self.signals.progress,
|
||||
)
|
||||
|
||||
if model_path == "":
|
||||
self.signals.error.emit(_("Error"))
|
||||
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
|
|
@ -417,7 +475,7 @@ class ModelDownloader(QRunnable):
|
|||
if downloaded:
|
||||
self.signals.finished.emit(file_path)
|
||||
except requests.RequestException:
|
||||
self.signals.error.emit("A connection error occurred")
|
||||
self.signals.error.emit(_("A connection error occurred"))
|
||||
logging.exception("")
|
||||
except Exception as exc:
|
||||
self.signals.error.emit(str(exc))
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ class Settings:
|
|||
|
||||
DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name"
|
||||
CUSTOM_OPENAI_BASE_URL = "transcriber/custom-openai-base-url"
|
||||
CUSTOM_FASTER_WHISPER_ID = "transcriber/custom-faster-whisper-id"
|
||||
HUGGINGFACE_MODEL_ID = "transcriber/huggingface-model-id"
|
||||
|
||||
SHORTCUTS = "shortcuts"
|
||||
|
||||
|
|
@ -50,6 +52,36 @@ class Settings:
|
|||
def set_value(self, key: Key, value: typing.Any) -> None:
    """Store ``value`` under the settings entry named by ``key.value``."""
    setting_name = key.value
    self.settings.setValue(setting_name, value)
|
||||
|
||||
def save_custom_model_id(self, model) -> None:
    """Persist the model's Hugging Face id under the key for its model type.

    Only Faster Whisper and Hugging Face models are stored; any other
    model type is ignored.
    """
    # Function-scope import kept from the original; presumably avoids a
    # circular import with buzz.model_loader -- TODO confirm.
    from buzz.model_loader import ModelType

    model_type = model.model_type
    if model_type == ModelType.FASTER_WHISPER:
        settings_key = Settings.Key.CUSTOM_FASTER_WHISPER_ID
    elif model_type == ModelType.HUGGING_FACE:
        settings_key = Settings.Key.HUGGINGFACE_MODEL_ID
    else:
        return

    self.set_value(settings_key, model.hugging_face_model_id)
|
||||
|
||||
def load_custom_model_id(self, model) -> str:
    """Look up the stored Hugging Face id for the model's type.

    Faster Whisper and Hugging Face models each read their own settings
    key, passing ``""`` as the second argument to ``self.value``. Any other
    model type yields an empty string.
    """
    # Function-scope import kept from the original; presumably avoids a
    # circular import with buzz.model_loader -- TODO confirm.
    from buzz.model_loader import ModelType

    model_type = model.model_type
    if model_type == ModelType.FASTER_WHISPER:
        return self.value(Settings.Key.CUSTOM_FASTER_WHISPER_ID, "")
    if model_type == ModelType.HUGGING_FACE:
        return self.value(Settings.Key.HUGGINGFACE_MODEL_ID, "")
    return ""
|
||||
|
||||
def value(
|
||||
self,
|
||||
key: Key,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import tqdm
|
|||
from PyQt6.QtCore import QObject
|
||||
|
||||
from buzz.conn import pipe_stderr
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.model_loader import ModelType, WhisperModelSize
|
||||
from buzz.transformers_whisper import TransformersWhisper
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
|
||||
|
|
@ -131,8 +131,13 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
|
||||
@classmethod
|
||||
def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
if task.transcription_options.model.whisper_model_size == WhisperModelSize.CUSTOM:
|
||||
model_size_or_path = task.transcription_options.model.hugging_face_model_id
|
||||
else:
|
||||
model_size_or_path = task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
|
||||
|
||||
model = faster_whisper.WhisperModel(
|
||||
model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
|
||||
model_size_or_path=model_size_or_path
|
||||
)
|
||||
whisper_segments, info = model.transcribe(
|
||||
audio=task.file_path,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import Qt, QThreadPool
|
||||
|
|
@ -18,8 +19,13 @@ from buzz.model_loader import (
|
|||
TranscriptionModel,
|
||||
ModelDownloader,
|
||||
)
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import (
|
||||
HuggingFaceSearchLineEdit,
|
||||
)
|
||||
|
||||
|
||||
class ModelsPreferencesWidget(QWidget):
|
||||
|
|
@ -32,6 +38,7 @@ class ModelsPreferencesWidget(QWidget):
|
|||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.settings = Settings()
|
||||
self.model_downloader: Optional[ModelDownloader] = None
|
||||
|
||||
model_types = [
|
||||
|
|
@ -67,6 +74,20 @@ class ModelsPreferencesWidget(QWidget):
|
|||
|
||||
buttons_layout = QHBoxLayout()
|
||||
|
||||
self.custom_model_id_input = HuggingFaceSearchLineEdit()
|
||||
self.custom_model_id_input.setObjectName("ModelIdInput")
|
||||
|
||||
self.custom_model_id_input.setPlaceholderText(_("Huggingface ID of a Faster whisper model"))
|
||||
self.custom_model_id_input.textChanged.connect(self.on_custom_model_id_input_changed)
|
||||
layout.addRow("", self.custom_model_id_input)
|
||||
self.custom_model_id_input.hide()
|
||||
|
||||
self.custom_model_link_input = LineEdit()
|
||||
self.custom_model_link_input.setObjectName("ModelLinkInput")
|
||||
self.custom_model_link_input.textChanged.connect(self.on_custom_model_link_input_changed)
|
||||
layout.addRow("", self.custom_model_link_input)
|
||||
self.custom_model_link_input.hide()
|
||||
|
||||
self.download_button = QPushButton(_("Download"))
|
||||
self.download_button.setObjectName("DownloadButton")
|
||||
self.download_button.clicked.connect(self.on_download_button_clicked)
|
||||
|
|
@ -100,17 +121,11 @@ class ModelsPreferencesWidget(QWidget):
|
|||
self.model.whisper_model_size = item_data
|
||||
self.reset()
|
||||
|
||||
@staticmethod
|
||||
def can_delete_model(model: TranscriptionModel):
|
||||
return (
|
||||
model.model_type == ModelType.WHISPER
|
||||
or model.model_type == ModelType.WHISPER_CPP
|
||||
) and model.get_local_model_path() is not None
|
||||
|
||||
def reset(self):
|
||||
# reset buttons
|
||||
path = self.model.get_local_model_path()
|
||||
self.download_button.setVisible(path is None)
|
||||
self.download_button.setEnabled(self.model.whisper_model_size != WhisperModelSize.CUSTOM)
|
||||
self.delete_button.setVisible(self.model.is_deletable())
|
||||
self.show_file_location_button.setVisible(self.model.is_deletable())
|
||||
|
||||
|
|
@ -129,12 +144,45 @@ class ModelsPreferencesWidget(QWidget):
|
|||
self.model_list_widget.setHeaderHidden(True)
|
||||
self.model_list_widget.setAlternatingRowColors(True)
|
||||
|
||||
self.model.hugging_face_model_id = self.settings.load_custom_model_id(self.model)
|
||||
self.custom_model_id_input.setText(self.model.hugging_face_model_id)
|
||||
|
||||
if (self.model.whisper_model_size == WhisperModelSize.CUSTOM
|
||||
and self.model.model_type == ModelType.FASTER_WHISPER):
|
||||
self.custom_model_id_input.show()
|
||||
self.download_button.setEnabled(
|
||||
self.model.hugging_face_model_id != ""
|
||||
)
|
||||
else:
|
||||
self.custom_model_id_input.hide()
|
||||
|
||||
if self.model.model_type == ModelType.WHISPER_CPP:
|
||||
self.custom_model_link_input.setPlaceholderText(
|
||||
_("Download link to Whisper.cpp ggml model file")
|
||||
)
|
||||
|
||||
if (self.model.whisper_model_size == WhisperModelSize.CUSTOM
|
||||
and self.model.model_type == ModelType.WHISPER_CPP
|
||||
and path is None):
|
||||
self.custom_model_link_input.show()
|
||||
self.download_button.setEnabled(
|
||||
self.custom_model_link_input.text() != "")
|
||||
else:
|
||||
self.custom_model_link_input.hide()
|
||||
|
||||
if self.model is None:
|
||||
return
|
||||
|
||||
for model_size in WhisperModelSize:
|
||||
# Skip custom size for OpenAI Whisper
|
||||
if (self.model.model_type == ModelType.WHISPER and
|
||||
model_size == WhisperModelSize.CUSTOM):
|
||||
continue
|
||||
|
||||
model = TranscriptionModel(
|
||||
model_type=self.model.model_type, whisper_model_size=model_size
|
||||
model_type=self.model.model_type,
|
||||
whisper_model_size=WhisperModelSize(model_size),
|
||||
hugging_face_model_id=self.model.hugging_face_model_id,
|
||||
)
|
||||
model_path = model.get_local_model_path()
|
||||
parent = downloaded_item if model_path is not None else available_item
|
||||
|
|
@ -149,6 +197,16 @@ class ModelsPreferencesWidget(QWidget):
|
|||
self.model.model_type = model_type
|
||||
self.reset()
|
||||
|
||||
def on_custom_model_id_input_changed(self, text):
    """Handle edits to the custom model id input.

    Copies the text onto the current model, saves it through the settings
    object, and enables the download button only for a non-empty id.
    """
    self.model.hugging_face_model_id = text
    self.settings.save_custom_model_id(self.model)

    has_model_id = self.model.hugging_face_model_id != ""
    self.download_button.setEnabled(has_model_id)
|
||||
|
||||
def on_custom_model_link_input_changed(self, text):
    """Enable the download button only while the link input is non-empty."""
    link_present = text != ""
    self.download_button.setEnabled(link_present)
|
||||
|
||||
def on_download_button_clicked(self):
|
||||
self.progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.model.model_type,
|
||||
|
|
@ -159,7 +217,15 @@ class ModelsPreferencesWidget(QWidget):
|
|||
|
||||
self.download_button.setEnabled(False)
|
||||
|
||||
self.model_downloader = ModelDownloader(model=self.model)
|
||||
if (self.model.whisper_model_size == WhisperModelSize.CUSTOM and
|
||||
self.model.model_type == ModelType.WHISPER_CPP):
|
||||
self.model_downloader = ModelDownloader(
|
||||
model=self.model,
|
||||
custom_model_url=self.custom_model_link_input.text()
|
||||
)
|
||||
else:
|
||||
self.model_downloader = ModelDownloader(model=self.model)
|
||||
|
||||
self.model_downloader.signals.finished.connect(self.on_download_completed)
|
||||
self.model_downloader.signals.progress.connect(self.on_download_progress)
|
||||
self.model_downloader.signals.error.connect(self.on_download_error)
|
||||
|
|
@ -185,10 +251,12 @@ class ModelsPreferencesWidget(QWidget):
|
|||
|
||||
def on_download_error(self, error: str):
|
||||
self.progress_dialog.cancel()
|
||||
self.progress_dialog.close()
|
||||
self.progress_dialog = None
|
||||
self.download_button.setEnabled(True)
|
||||
self.reset()
|
||||
QMessageBox.warning(self, _("Error"), f"{_('Download failed')}: {error}")
|
||||
download_failed_label = _('Download failed')
|
||||
QMessageBox.warning(self, _("Error"), f"{download_failed_label}: {error}")
|
||||
|
||||
def on_download_progress(self, progress: tuple):
|
||||
self.progress_dialog.set_value(float(progress[0]) / progress[1])
|
||||
|
|
|
|||
|
|
@ -14,12 +14,11 @@ from PyQt6.QtCore import (
|
|||
QEvent,
|
||||
)
|
||||
from PyQt6.QtGui import QKeyEvent
|
||||
from PyQt6.QtCore import QSettings
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
|
||||
from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.settings.settings import APP_NAME
|
||||
|
||||
|
||||
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
|
||||
|
|
@ -29,11 +28,12 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
default_value: str,
|
||||
default_value: str = "",
|
||||
network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__(default_value, parent)
|
||||
self.setPlaceholderText(_("Huggingface ID of a model"))
|
||||
|
||||
self.setMinimumWidth(150)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from PyQt6.QtCore import pyqtSignal
|
|||
from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.model_loader import ModelType, WhisperModelSize
|
||||
from buzz.transcriber.transcriber import TranscriptionOptions, Task
|
||||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
|
|
@ -29,6 +30,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__(title="", parent=parent)
|
||||
self.settings = Settings()
|
||||
self.transcription_options = default_transcription_options
|
||||
|
||||
self.form_layout = QFormLayout(self)
|
||||
|
|
@ -49,12 +51,8 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems(
|
||||
[size.value.title() for size in WhisperModelSize]
|
||||
[size.value.title() for size in WhisperModelSize if size != WhisperModelSize.CUSTOM]
|
||||
)
|
||||
if default_transcription_options.model.whisper_model_size is not None:
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
default_transcription_options.model.whisper_model_size.value.title()
|
||||
)
|
||||
self.whisper_model_size_combo_box.currentTextChanged.connect(
|
||||
self.on_whisper_model_size_changed
|
||||
)
|
||||
|
|
@ -72,6 +70,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.hugging_face_search_line_edit.model_selected.connect(
|
||||
self.on_hugging_face_model_changed
|
||||
)
|
||||
self.hugging_face_search_line_edit.setVisible(False)
|
||||
|
||||
self.tasks_combo_box = TasksComboBox(
|
||||
default_task=self.transcription_options.task, parent=self
|
||||
|
|
@ -122,9 +121,40 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
|
||||
def reset_visible_rows(self):
|
||||
model_type = self.transcription_options.model.model_type
|
||||
whisper_model_size = self.transcription_options.model.whisper_model_size
|
||||
|
||||
if (model_type == ModelType.HUGGING_FACE
|
||||
or (whisper_model_size == WhisperModelSize.CUSTOM
|
||||
and model_type == ModelType.FASTER_WHISPER)):
|
||||
self.transcription_options.model.hugging_face_model_id = (
|
||||
self.settings.load_custom_model_id(self.transcription_options.model))
|
||||
self.hugging_face_search_line_edit.setText(
|
||||
self.transcription_options.model.hugging_face_model_id)
|
||||
|
||||
self.form_layout.setRowVisible(
|
||||
self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE
|
||||
self.hugging_face_search_line_edit,
|
||||
(model_type == ModelType.HUGGING_FACE)
|
||||
or (model_type == ModelType.FASTER_WHISPER
|
||||
and whisper_model_size == WhisperModelSize.CUSTOM),
|
||||
)
|
||||
|
||||
custom_model_index = (self.whisper_model_size_combo_box
|
||||
.findText(WhisperModelSize.CUSTOM.value.title()))
|
||||
if (model_type == ModelType.WHISPER
|
||||
and whisper_model_size == WhisperModelSize.CUSTOM
|
||||
and custom_model_index != -1):
|
||||
self.whisper_model_size_combo_box.removeItem(custom_model_index)
|
||||
|
||||
if ((model_type == ModelType.WHISPER_CPP or model_type == ModelType.FASTER_WHISPER)
|
||||
and custom_model_index == -1):
|
||||
self.whisper_model_size_combo_box.addItem(
|
||||
WhisperModelSize.CUSTOM.value.title()
|
||||
)
|
||||
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
self.transcription_options.model.whisper_model_size.value.title()
|
||||
)
|
||||
|
||||
self.form_layout.setRowVisible(
|
||||
self.whisper_model_size_combo_box,
|
||||
(model_type == ModelType.WHISPER)
|
||||
|
|
@ -146,8 +176,13 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
def on_whisper_model_size_changed(self, text: str):
    """Apply a new model size chosen in the size combo box.

    ``text`` is lower-cased and converted to a ``WhisperModelSize``. The
    visible rows are then refreshed and the options-changed signal is
    emitted with the updated options.
    """
    self.transcription_options.model.whisper_model_size = WhisperModelSize(
        text.lower()
    )
    self.reset_visible_rows()
    self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_hugging_face_model_changed(self, model: str):
    """Record a newly selected Hugging Face model id.

    Sets the id on the current options' model, emits the options-changed
    signal, and then saves the id through the settings object.
    """
    options = self.transcription_options
    options.model.hugging_face_model_id = model
    self.transcription_options_changed.emit(options)

    # Read the model from self again after emitting, as the original does.
    self.settings.save_custom_model_id(self.transcription_options.model)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue