Mirror of https://github.com/chidiwilliams/buzz.git
Synced 2026-03-14 22:55:46 +01:00
599 lines · 20 KiB · Python
import enum
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import time
|
|
import threading
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import warnings
|
|
import platform
|
|
import requests
|
|
import whisper
|
|
import huggingface_hub
|
|
import zipfile
|
|
from dataclasses import dataclass
|
|
from typing import Optional, List
|
|
|
|
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
|
|
from platformdirs import user_cache_dir
|
|
from huggingface_hub.errors import LocalEntryNotFoundError
|
|
|
|
from buzz.locale import _
|
|
|
|
|
|
# Resolve the directory where all models are cached: the platformdirs
# per-user cache dir + "models", overridable via the BUZZ_MODEL_ROOT
# environment variable. Created eagerly so downloads can assume it exists.
model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
os.makedirs(model_root_dir, exist_ok=True)

logging.debug("Model root directory: %s", model_root_dir)
|
|
|
|
|
|
class WhisperModelSize(str, enum.Enum):
    """Known Whisper model sizes, plus entries for custom and LUMII models.

    Subclasses ``str`` so members compare equal to their plain string values.
    """

    TINY = "tiny"
    TINYEN = "tiny.en"
    BASE = "base"
    BASEEN = "base.en"
    SMALL = "small"
    SMALLEN = "small.en"
    MEDIUM = "medium"
    MEDIUMEN = "medium.en"
    LARGE = "large"
    LARGEV2 = "large-v2"
    LARGEV3 = "large-v3"
    LARGEV3TURBO = "large-v3-turbo"
    CUSTOM = "custom"
    LUMII = "lumii"

    def to_faster_whisper_model_size(self) -> str:
        """Return the faster-whisper name for this size; LARGE maps to "large-v1"."""
        return "large-v1" if self is WhisperModelSize.LARGE else self.value

    def to_whisper_cpp_model_size(self) -> str:
        """Return the whisper.cpp name for this size; LARGE maps to "large-v1"."""
        return "large-v1" if self is WhisperModelSize.LARGE else self.value

    def __str__(self):
        # Human-friendly label, e.g. "Tiny" or "Large-v2".
        return self.value.capitalize()
|
|
|
|
|
|
class ModelType(enum.Enum):
    """The transcription backends Buzz can use."""

    WHISPER = "Whisper"
    WHISPER_CPP = "Whisper.cpp"
    HUGGING_FACE = "Hugging Face"
    FASTER_WHISPER = "Faster Whisper"
    OPEN_AI_WHISPER_API = "OpenAI Whisper API"

    @property
    def supports_initial_prompt(self):
        """Whether this backend accepts an initial prompt for transcription."""
        return self in (
            ModelType.WHISPER,
            ModelType.WHISPER_CPP,
            ModelType.OPEN_AI_WHISPER_API,
            ModelType.FASTER_WHISPER,
        )

    def is_available(self):
        """Whether this backend can be offered on the current platform."""
        # Faster Whisper is hidden on macOS x86_64.
        # See: https://github.com/SYSTRAN/faster-whisper/issues/541
        on_intel_mac = (
            platform.system() == "Darwin" and platform.machine() == "x86_64"
        )
        return not (self == ModelType.FASTER_WHISPER and on_intel_mac)

    def is_manually_downloadable(self):
        """Whether the user can download this backend's model files manually."""
        return self in (
            ModelType.WHISPER,
            ModelType.WHISPER_CPP,
            ModelType.FASTER_WHISPER,
        )
|
|
|
|
|
|
# Files fetched when downloading a Hugging Face model snapshot. The weight
# files come first (largest by size) so the download code can treat the
# leading entries as the "large" files; the rest are small config/tokenizer
# files.
HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
    "model.safetensors",  # largest by size first
    "pytorch_model.bin",
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "model.safetensors.index.json",
    "added_tokens.json",
    "config.json",
    "generation_config.json",
    "merges.txt",
    "normalizer.json",
    "preprocessor_config.json",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
]
|
|
|
|
|
|
@dataclass()
|
|
class TranscriptionModel:
|
|
def __init__(
|
|
self,
|
|
model_type: ModelType = ModelType.WHISPER,
|
|
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY,
|
|
hugging_face_model_id: Optional[str] = ""
|
|
):
|
|
self.model_type = model_type
|
|
self.whisper_model_size = whisper_model_size
|
|
self.hugging_face_model_id = hugging_face_model_id
|
|
|
|
def __str__(self):
|
|
match self.model_type:
|
|
case ModelType.WHISPER:
|
|
return f"Whisper ({self.whisper_model_size})"
|
|
case ModelType.WHISPER_CPP:
|
|
return f"Whisper.cpp ({self.whisper_model_size})"
|
|
case ModelType.HUGGING_FACE:
|
|
return f"Hugging Face ({self.hugging_face_model_id})"
|
|
case ModelType.FASTER_WHISPER:
|
|
return f"Faster Whisper ({self.whisper_model_size})"
|
|
case ModelType.OPEN_AI_WHISPER_API:
|
|
return "OpenAI Whisper API"
|
|
case _:
|
|
raise Exception("Unknown model type")
|
|
|
|
def is_deletable(self):
|
|
return (
|
|
self.model_type == ModelType.WHISPER
|
|
or self.model_type == ModelType.WHISPER_CPP
|
|
or self.model_type == ModelType.FASTER_WHISPER
|
|
) and self.get_local_model_path() is not None
|
|
|
|
def open_file_location(self):
|
|
model_path = self.get_local_model_path()
|
|
|
|
if (self.model_type == ModelType.HUGGING_FACE
|
|
or self.model_type == ModelType.FASTER_WHISPER):
|
|
model_path = os.path.dirname(model_path)
|
|
|
|
if model_path is None:
|
|
return
|
|
self.open_path(path=os.path.dirname(model_path))
|
|
|
|
@staticmethod
|
|
def default():
|
|
model_type = next(
|
|
model_type for model_type in ModelType if model_type.is_available()
|
|
)
|
|
return TranscriptionModel(model_type=model_type)
|
|
|
|
@staticmethod
|
|
def open_path(path: str):
|
|
if sys.platform == "win32":
|
|
os.startfile(path)
|
|
else:
|
|
opener = "open" if sys.platform == "darwin" else "xdg-open"
|
|
subprocess.call([opener, path])
|
|
|
|
def delete_local_file(self):
|
|
model_path = self.get_local_model_path()
|
|
|
|
if (self.model_type == ModelType.HUGGING_FACE
|
|
or self.model_type == ModelType.FASTER_WHISPER):
|
|
model_path = os.path.dirname(os.path.dirname(model_path))
|
|
|
|
logging.debug("Deleting model directory: %s", model_path)
|
|
|
|
shutil.rmtree(model_path, ignore_errors=True)
|
|
return
|
|
|
|
logging.debug("Deleting model file: %s", model_path)
|
|
os.remove(model_path)
|
|
|
|
def get_local_model_path(self) -> Optional[str]:
|
|
if self.model_type == ModelType.WHISPER_CPP:
|
|
file_path = get_whisper_cpp_file_path(size=self.whisper_model_size)
|
|
if not file_path or not os.path.exists(file_path) or not os.path.isfile(file_path):
|
|
return None
|
|
return file_path
|
|
|
|
if self.model_type == ModelType.WHISPER:
|
|
file_path = get_whisper_file_path(size=self.whisper_model_size)
|
|
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
|
return None
|
|
return file_path
|
|
|
|
if self.model_type == ModelType.FASTER_WHISPER:
|
|
try:
|
|
return download_faster_whisper_model(
|
|
model=self, local_files_only=True
|
|
)
|
|
except (ValueError, FileNotFoundError):
|
|
return None
|
|
|
|
if self.model_type == ModelType.OPEN_AI_WHISPER_API:
|
|
return ""
|
|
|
|
if self.model_type == ModelType.HUGGING_FACE:
|
|
try:
|
|
return huggingface_hub.snapshot_download(
|
|
self.hugging_face_model_id,
|
|
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
|
|
local_files_only=True,
|
|
cache_dir=model_root_dir,
|
|
etag_timeout=60
|
|
)
|
|
except (ValueError, FileNotFoundError):
|
|
return None
|
|
|
|
raise Exception("Unknown model type")
|
|
|
|
|
|
# Hugging Face repo hosting the official whisper.cpp GGML model files.
WHISPER_CPP_REPO_ID = "ggerganov/whisper.cpp"
# Alternative repo used for the LUMII model size.
WHISPER_CPP_LUMII_REPO_ID = "RaivisDejus/whisper.cpp-lv"
|
|
|
|
|
|
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
    """Return the local path of the whisper.cpp GGML file for *size*.

    Looks the model up in the local Hugging Face cache only (no network).
    Returns '' when the model has not been downloaded yet. Custom models
    resolve to a fixed filename directly under the model root.
    """
    if size == WhisperModelSize.CUSTOM:
        # Plain string, not an f-string — there is nothing to interpolate.
        return os.path.join(model_root_dir, "ggml-model-whisper-custom.bin")

    repo_id = WHISPER_CPP_REPO_ID

    if size == WhisperModelSize.LUMII:
        repo_id = WHISPER_CPP_LUMII_REPO_ID

    model_filename = f"ggml-{size.to_whisper_cpp_model_size()}.bin"

    try:
        model_path = huggingface_hub.snapshot_download(
            repo_id=repo_id,
            allow_patterns=[model_filename],
            local_files_only=True,
            cache_dir=model_root_dir,
            etag_timeout=60
        )

        return os.path.join(model_path, model_filename)
    except LocalEntryNotFoundError:
        # Not in the local cache yet.
        return ''
|
|
|
|
|
|
def get_whisper_file_path(size: WhisperModelSize) -> str:
    """Return the on-disk path for an OpenAI Whisper checkpoint of *size*."""
    whisper_dir = os.path.join(model_root_dir, "whisper")

    # Custom models live in a fixed location under the whisper directory.
    if size == WhisperModelSize.CUSTOM:
        return os.path.join(whisper_dir, "custom")

    # Derive the filename from the official download URL for this size.
    model_url = whisper._MODELS[size.value]
    return os.path.join(whisper_dir, os.path.basename(model_url))
|
|
|
|
|
|
class HuggingfaceDownloadMonitor:
    """Polls Hugging Face download temp files in a background thread and
    reports their growing size as (current, total) progress tuples."""

    def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size: int):
        self.model_root = model_root
        self.progress = progress
        # Inflate the total to keep the dialog open even if it reports 100%.
        self.total_file_size = round(total_file_size * 1.1)
        self.incomplete_download_root = None
        self.stop_event = threading.Event()
        self.monitor_thread = None
        self.set_download_roots()

    def set_download_roots(self):
        """Compute the "blobs" directory (two levels above the snapshot root)
        where huggingface_hub stores *.incomplete download files."""
        normalized_model_root = os.path.normpath(self.model_root)
        two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
        self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))

    def clean_tmp_files(self):
        """Remove stale tmp* files so old partial downloads don't skew progress."""
        for filename in os.listdir(model_root_dir):
            if filename.startswith("tmp"):
                os.remove(os.path.join(model_root_dir, filename))

    def monitor_file_size(self):
        """Thread body: poll temp/incomplete file sizes until stopped."""
        while not self.stop_event.is_set():
            if model_root_dir is not None:
                for filename in os.listdir(model_root_dir):
                    if filename.startswith("tmp"):
                        file_size = os.path.getsize(os.path.join(model_root_dir, filename))
                        self.progress.emit((file_size, self.total_file_size))

            # The blobs dir may not exist yet; an unguarded listdir raised
            # FileNotFoundError and silently killed the monitor thread.
            if os.path.isdir(self.incomplete_download_root):
                for filename in os.listdir(self.incomplete_download_root):
                    if filename.endswith(".incomplete"):
                        file_size = os.path.getsize(os.path.join(self.incomplete_download_root, filename))
                        self.progress.emit((file_size, self.total_file_size))

            time.sleep(2)

    def start_monitoring(self):
        """Clean stale temp files and start the polling thread."""
        self.clean_tmp_files()
        self.monitor_thread = threading.Thread(target=self.monitor_file_size)
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Emit a final 100% progress tuple and join the polling thread."""
        self.progress.emit((self.total_file_size, self.total_file_size))

        if self.monitor_thread is not None:
            self.stop_event.set()
            self.monitor_thread.join()
|
|
|
|
|
|
def get_file_size(url):
    """Return the Content-Length (bytes) reported for *url*, following redirects.

    Raises requests.RequestException on network errors / timeout and KeyError
    if the server sends no Content-Length header.
    """
    # Timeout so a stalled HEAD request cannot hang the download flow forever.
    response = requests.head(url, allow_redirects=True, timeout=15)
    response.raise_for_status()
    return int(response.headers['Content-Length'])
|
|
|
|
|
|
def download_from_huggingface(
    repo_id: str,
    allow_patterns: List[str],
    progress: pyqtSignal(tuple),
    num_large_files: int = 1
):
    """Download a model snapshot from Hugging Face with progress reporting.

    The first ``num_large_files`` entries of ``allow_patterns`` are treated as
    the large weight files: the small files are fetched first, then the large
    ones are downloaded while a monitor thread polls temp-file sizes and emits
    ``progress`` (current, total) tuples.

    Returns the snapshot root directory, or "" on failure.
    """
    progress.emit((0, 100))

    # Fetch everything except the large weight files first.
    try:
        model_root = huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns[num_large_files:],  # all, but largest
            cache_dir=model_root_dir,
            etag_timeout=60
        )
    except Exception as exc:
        logging.exception(exc)
        return ""

    progress.emit((1, 100))

    # Probe the size of the largest weight file so progress has a total.
    largest_file_size = 0
    for pattern in allow_patterns[:num_large_files]:
        try:
            file_url = huggingface_hub.hf_hub_url(repo_id, pattern)
            file_size = get_file_size(file_url)

            if file_size > largest_file_size:
                largest_file_size = file_size

        except requests.exceptions.RequestException:
            # A pattern may not exist in this repo (alternative weight
            # filenames); skip it.
            continue

    model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, largest_file_size)
    model_download_monitor.start_monitoring()

    try:
        huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns[:num_large_files],  # largest
            cache_dir=model_root_dir,
            etag_timeout=60
        )
    except Exception as exc:
        logging.exception(exc)
        model_download_monitor.stop_monitoring()
        # Cleanup to prevent incomplete downloads errors
        if os.path.exists(model_root):
            shutil.rmtree(model_root)
        return ""

    model_download_monitor.stop_monitoring()

    return model_root
|
|
|
|
|
|
def download_faster_whisper_model(
    model: TranscriptionModel, local_files_only=False, progress: pyqtSignal(tuple) = None
):
    """Resolve (and, unless local_files_only, download) a Faster Whisper model.

    Returns the local snapshot path. Raises ValueError when a custom model is
    selected without a Hugging Face repo id.
    """
    size = model.whisper_model_size.to_faster_whisper_model_size()
    custom_repo_id = model.hugging_face_model_id

    # NOTE: WhisperModelSize subclasses str, so comparing the plain string
    # `size` against enum members works as intended.
    if size == WhisperModelSize.CUSTOM and custom_repo_id == "":
        raise ValueError("Custom model id is not provided")

    if size == WhisperModelSize.CUSTOM:
        repo_id = custom_repo_id
    elif size == WhisperModelSize.LARGEV3TURBO:
        # Replicating models from faster-whisper code
        # https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py#L29
        repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
    else:
        repo_id = "Systran/faster-whisper-%s" % size

    allow_patterns = [
        "model.bin",  # largest by size first
        "pytorch_model.bin",  # possible alternative model filename
        "config.json",
        "preprocessor_config.json",
        "tokenizer.json",
        "vocabulary.*",
    ]

    # Cache-only lookup: raises/returns per snapshot_download semantics.
    if local_files_only:
        return huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns,
            local_files_only=True,
            cache_dir=model_root_dir,
            etag_timeout=60
        )

    # Full download with progress reporting; the first two patterns are the
    # potentially-large weight files.
    return download_from_huggingface(
        repo_id,
        allow_patterns=allow_patterns,
        progress=progress,
        num_large_files=2
    )
|
|
|
|
|
|
class ModelDownloader(QRunnable):
    """Background task that downloads the given model and reports status via
    Qt signals (progress tuples, finished path, error message)."""

    class Signals(QObject):
        finished = pyqtSignal(str)
        progress = pyqtSignal(tuple)  # (current, total)
        error = pyqtSignal(str)

    def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
        super().__init__()

        # Core ML encoder archives are only used on Apple Silicon Macs.
        self.is_coreml_supported = platform.system() == "Darwin" and platform.machine() == "arm64"
        self.signals = self.Signals()
        self.model = model
        self.stopped = False
        self.custom_model_url = custom_model_url

    def run(self) -> None:
        """Dispatch to the appropriate download strategy for the model type."""
        logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)

        if self.model.model_type == ModelType.WHISPER_CPP:
            # A custom URL bypasses Hugging Face entirely.
            if self.custom_model_url:
                url = self.custom_model_url
                file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
                return self.download_model_to_path(url=url, file_path=file_path)

            repo_id = WHISPER_CPP_REPO_ID

            if self.model.whisper_model_size == WhisperModelSize.LUMII:
                repo_id = WHISPER_CPP_LUMII_REPO_ID

            model_name = self.model.whisper_model_size.to_whisper_cpp_model_size()

            whisper_cpp_model_files = [
                f"ggml-{model_name}.bin",
                "README.md"
            ]
            num_large_files = 1
            if self.is_coreml_supported:
                # Also fetch the Core ML encoder archive on Apple Silicon.
                whisper_cpp_model_files = [
                    f"ggml-{model_name}.bin",
                    f"ggml-{model_name}-encoder.mlmodelc.zip",
                    "README.md"
                ]
                num_large_files = 2

            model_path = download_from_huggingface(
                repo_id=repo_id,
                allow_patterns=whisper_cpp_model_files,
                progress=self.signals.progress,
                num_large_files=num_large_files
            )

            if self.is_coreml_supported:
                with zipfile.ZipFile(
                        os.path.join(model_path, f"ggml-{model_name}-encoder.mlmodelc.zip"), 'r') as zip_ref:
                    zip_ref.extractall(model_path)

            self.signals.finished.emit(os.path.join(model_path, f"ggml-{model_name}.bin"))
            return

        if self.model.model_type == ModelType.WHISPER:
            url = whisper._MODELS[self.model.whisper_model_size.value]
            file_path = get_whisper_file_path(size=self.model.whisper_model_size)
            # Whisper download URLs embed the expected SHA256 as the
            # second-to-last path component.
            expected_sha256 = url.split("/")[-2]
            return self.download_model_to_path(
                url=url, file_path=file_path, expected_sha256=expected_sha256
            )

        if self.model.model_type == ModelType.FASTER_WHISPER:
            model_path = download_faster_whisper_model(
                model=self.model,
                progress=self.signals.progress,
            )

            if model_path == "":
                self.signals.error.emit(_("Error"))

            self.signals.finished.emit(model_path)
            return

        if self.model.model_type == ModelType.HUGGING_FACE:
            model_path = download_from_huggingface(
                self.model.hugging_face_model_id,
                allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
                progress=self.signals.progress,
                num_large_files=4
            )

            if model_path == "":
                self.signals.error.emit(_("Error"))

            self.signals.finished.emit(model_path)
            return

        if self.model.model_type == ModelType.OPEN_AI_WHISPER_API:
            # Nothing to download for the API backend.
            self.signals.finished.emit("")
            return

        raise Exception("Invalid model type: " + self.model.model_type.value)

    def download_model_to_path(
        self, url: str, file_path: str, expected_sha256: Optional[str] = None
    ):
        """Download *url* to *file_path*, emitting finished/error signals and
        removing the partial file on failure."""
        try:
            downloaded = self.download_model(url, file_path, expected_sha256)
            if downloaded:
                self.signals.finished.emit(file_path)
        except requests.RequestException:
            self.signals.error.emit(_("A connection error occurred"))
            if os.path.exists(file_path):
                os.remove(file_path)
            logging.exception("")
        except Exception as exc:
            self.signals.error.emit(str(exc))
            if os.path.exists(file_path):
                os.remove(file_path)
            logging.exception(exc)

    def download_model(
        self, url: str, file_path: str, expected_sha256: Optional[str]
    ) -> bool:
        """Stream *url* to *file_path*, verifying SHA256 when provided.

        Returns True on success (or when a valid file already exists), False
        when cancelled via self.stopped. Raises requests.RequestException on
        network errors and RuntimeError on checksum mismatch.
        """
        logging.debug(f"Downloading model from {url} to {file_path}")

        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        if os.path.exists(file_path) and not os.path.isfile(file_path):
            raise RuntimeError(f"{file_path} exists and is not a regular file")

        if os.path.isfile(file_path):
            if expected_sha256 is None:
                return True

            # Use a context manager — the previous open().read() leaked the
            # file handle.
            with open(file_path, "rb") as model_file:
                model_bytes = model_file.read()
            model_sha256 = hashlib.sha256(model_bytes).hexdigest()
            if model_sha256 == expected_sha256:
                return True
            else:
                warnings.warn(
                    f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
                )

        # Downloads the model using the requests module instead of urllib to
        # use the certs from certifi when the app is running in frozen mode
        with requests.get(url, stream=True, timeout=15) as source, open(
            file_path, "wb"
        ) as output:
            source.raise_for_status()
            total_size = float(source.headers.get("Content-Length", 0))
            current = 0.0
            self.signals.progress.emit((current, total_size))
            for chunk in source.iter_content(chunk_size=8192):
                if self.stopped:
                    return False
                output.write(chunk)
                current += len(chunk)
                self.signals.progress.emit((current, total_size))

        if expected_sha256 is not None:
            with open(file_path, "rb") as model_file:
                model_bytes = model_file.read()
            if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
                raise RuntimeError(
                    "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
                    "model."
                )

        logging.debug("Downloaded model")

        return True

    def cancel(self):
        """Ask an in-flight download loop to stop at the next chunk boundary."""
        self.stopped = True
|
|
|
|
|
|
def get_custom_api_whisper_model(base_url: str):
    """Pick the Whisper model name expected by a custom OpenAI-compatible API."""
    # Groq exposes Whisper under its own model name; everything else gets the
    # standard OpenAI identifier.
    return "whisper-large-v3" if "api.groq.com" in base_url else "whisper-1"
|