buzz/buzz/model_loader.py
2024-08-27 17:43:37 +00:00

573 lines
19 KiB
Python

import enum
import hashlib
import logging
import os
import time
import threading
import shutil
import subprocess
import sys
import tempfile
import warnings
import platform
from dataclasses import dataclass
from typing import Optional, List
import requests
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
from platformdirs import user_cache_dir
import faster_whisper
import whisper
import huggingface_hub
from buzz.locale import _
# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
# the DLL in some envs is fixed.
try:
    import buzz.whisper_cpp as whisper_cpp  # noqa: F401
except ImportError:
    # The binary could not be imported: record that so ModelType.is_available()
    # can hide the Whisper.cpp option, and log the traceback for diagnosis.
    LOADED_WHISPER_CPP_BINARY = False
    logging.exception("")
else:
    LOADED_WHISPER_CPP_BINARY = True
# Root folder for all downloaded models: BUZZ_MODEL_ROOT when set, otherwise a
# "models" folder inside the per-user Buzz cache directory. An empty
# BUZZ_MODEL_ROOT is treated as unset, because os.makedirs("") would raise
# FileNotFoundError while importing this module.
model_root_dir = os.getenv("BUZZ_MODEL_ROOT") or os.path.join(
    user_cache_dir("Buzz"), "models"
)
os.makedirs(model_root_dir, exist_ok=True)
logging.debug("Model root directory: %s", model_root_dir)
class WhisperModelSize(str, enum.Enum):
    """Whisper model sizes; each value is the size name as a string."""

    TINY = "tiny"
    BASE = "base"
    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"
    LARGEV2 = "large-v2"
    LARGEV3 = "large-v3"
    CUSTOM = "custom"

    def to_faster_whisper_model_size(self) -> str:
        """Return the size name used for Faster Whisper ("large" becomes "large-v1")."""
        return "large-v1" if self is WhisperModelSize.LARGE else self.value

    def to_whisper_cpp_model_size(self) -> str:
        """Return the size name used for Whisper.cpp ("large" becomes "large-v1")."""
        if self is not WhisperModelSize.LARGE:
            return self.value
        return "large-v1"

    def __str__(self):
        # Display form, e.g. "large-v2" -> "Large-v2"
        text = self.value
        return text.capitalize()
class ModelType(enum.Enum):
    """Transcription back-ends; each value is the back-end's display name."""

    WHISPER = "Whisper"
    WHISPER_CPP = "Whisper.cpp"
    HUGGING_FACE = "Hugging Face"
    FASTER_WHISPER = "Faster Whisper"
    OPEN_AI_WHISPER_API = "OpenAI Whisper API"

    @property
    def supports_initial_prompt(self):
        """True for back-ends that accept an initial prompt."""
        return self in {
            ModelType.WHISPER,
            ModelType.WHISPER_CPP,
            ModelType.OPEN_AI_WHISPER_API,
            ModelType.FASTER_WHISPER,
        }

    def is_available(self):
        """Return False for back-ends that must be hidden on this machine."""
        if self is ModelType.WHISPER_CPP:
            # Hide Whisper.cpp option if whisper.dll did not load correctly.
            # See: https://github.com/chidiwilliams/buzz/issues/274,
            # https://github.com/chidiwilliams/buzz/issues/197
            return bool(LOADED_WHISPER_CPP_BINARY)
        if self is ModelType.FASTER_WHISPER:
            # Hide Faster Whisper option on macOS x86_64
            # See: https://github.com/SYSTRAN/faster-whisper/issues/541
            on_intel_mac = (
                platform.system() == "Darwin" and platform.machine() == "x86_64"
            )
            return not on_intel_mac
        return True

    def is_manually_downloadable(self):
        """True for back-ends whose model files can be downloaded on demand."""
        return self in {
            ModelType.WHISPER,
            ModelType.WHISPER_CPP,
            ModelType.FASTER_WHISPER,
        }
# Files fetched from a Hugging Face repository for ModelType.HUGGING_FACE models.
# Order matters: download_from_huggingface() treats the first `num_large_files`
# entries as the large weight files. It downloads the remaining entries first,
# then downloads the large ones while reporting progress. ModelDownloader passes
# num_large_files=4, which covers the single-file and two-shard weight names
# listed first here.
HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
"model.safetensors", # largest by size first
"pytorch_model.bin",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
"added_tokens.json",
"config.json",
"generation_config.json",
"merges.txt",
"normalizer.json",
"preprocessor_config.json",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer_config.json",
"vocab.json",
]
@dataclass()
class TranscriptionModel:
def __init__(
self,
model_type: ModelType = ModelType.WHISPER,
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY,
hugging_face_model_id: Optional[str] = ""
):
self.model_type = model_type
self.whisper_model_size = whisper_model_size
self.hugging_face_model_id = hugging_face_model_id
def __str__(self):
match self.model_type:
case ModelType.WHISPER:
return f"Whisper ({self.whisper_model_size})"
case ModelType.WHISPER_CPP:
return f"Whisper.cpp ({self.whisper_model_size})"
case ModelType.HUGGING_FACE:
return f"Hugging Face ({self.hugging_face_model_id})"
case ModelType.FASTER_WHISPER:
return f"Faster Whisper ({self.whisper_model_size})"
case ModelType.OPEN_AI_WHISPER_API:
return "OpenAI Whisper API"
case _:
raise Exception("Unknown model type")
def is_deletable(self):
return (
self.model_type == ModelType.WHISPER
or self.model_type == ModelType.WHISPER_CPP
or self.model_type == ModelType.FASTER_WHISPER
) and self.get_local_model_path() is not None
def open_file_location(self):
model_path = self.get_local_model_path()
if (self.model_type == ModelType.HUGGING_FACE
or self.model_type == ModelType.FASTER_WHISPER):
model_path = os.path.dirname(model_path)
if model_path is None:
return
self.open_path(path=os.path.dirname(model_path))
@staticmethod
def default():
model_type = next(
model_type for model_type in ModelType if model_type.is_available()
)
return TranscriptionModel(model_type=model_type)
@staticmethod
def open_path(path: str):
if sys.platform == "win32":
os.startfile(path)
else:
opener = "open" if sys.platform == "darwin" else "xdg-open"
subprocess.call([opener, path])
def delete_local_file(self):
model_path = self.get_local_model_path()
if (self.model_type == ModelType.HUGGING_FACE
or self.model_type == ModelType.FASTER_WHISPER):
model_path = os.path.dirname(os.path.dirname(model_path))
logging.debug("Deleting model directory: %s", model_path)
shutil.rmtree(model_path, ignore_errors=True)
return
logging.debug("Deleting model file: %s", model_path)
os.remove(model_path)
def get_local_model_path(self) -> Optional[str]:
if self.model_type == ModelType.WHISPER_CPP:
file_path = get_whisper_cpp_file_path(size=self.whisper_model_size)
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return None
return file_path
if self.model_type == ModelType.WHISPER:
file_path = get_whisper_file_path(size=self.whisper_model_size)
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return None
return file_path
if self.model_type == ModelType.FASTER_WHISPER:
try:
return download_faster_whisper_model(
model=self, local_files_only=True
)
except (ValueError, FileNotFoundError):
return None
if self.model_type == ModelType.OPEN_AI_WHISPER_API:
return ""
if self.model_type == ModelType.HUGGING_FACE:
try:
return huggingface_hub.snapshot_download(
self.hugging_face_model_id,
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
local_files_only=True,
cache_dir=model_root_dir,
etag_timeout=60
)
except (ValueError, FileNotFoundError):
return None
raise Exception("Unknown model type")
# Expected SHA256 digests of the Whisper.cpp GGML weight files. Keys are the
# names returned by WhisperModelSize.to_whisper_cpp_model_size(), so plain
# "large" appears as "large-v1". ModelDownloader checks downloads against these
# digests. "custom" maps to None, so custom models are not checksum-verified.
WHISPER_CPP_MODELS_SHA256 = {
"tiny": "be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21",
"base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe",
"small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b",
"medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208",
"large-v1": "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d",
"large-v2": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
"large-v3": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2",
"custom": None,
}
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
    """Return the path of the Whisper.cpp GGML weights for *size* in the model root."""
    file_name = "ggml-model-whisper-" + size.value + ".bin"
    return os.path.join(model_root_dir, file_name)
def get_whisper_file_path(size: WhisperModelSize) -> str:
    """Return the path of the OpenAI Whisper checkpoint for *size*.

    Checkpoints live in a "whisper" sub-folder of the model root. The file
    name is taken from the download URL in ``whisper._MODELS``; custom models
    use the fixed name "custom".
    """
    whisper_dir = os.path.join(model_root_dir, "whisper")
    if size != WhisperModelSize.CUSTOM:
        file_name = os.path.basename(whisper._MODELS[size.value])
    else:
        file_name = "custom"
    return os.path.join(whisper_dir, file_name)
class HuggingfaceDownloadMonitor:
    """Report progress of a Hugging Face download by polling partial-file sizes.

    A background thread repeatedly looks for ``tmp*`` files in
    ``model_root_dir`` and ``*.incomplete`` files in the repository's
    ``blobs`` folder. It emits ``(current_size, total_file_size)`` tuples on
    ``progress`` for each one.
    """

    def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size: int):
        self.model_root = model_root
        self.progress = progress
        self.total_file_size = total_file_size
        self.incomplete_download_root = None
        self.stop_event = threading.Event()
        self.monitor_thread = None
        self.set_download_roots()

    def set_download_roots(self):
        """Locate the ``blobs`` folder two directory levels above ``model_root``."""
        normalized_model_root = os.path.normpath(self.model_root)
        two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
        self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))

    def clean_tmp_files(self):
        """Delete leftover ``tmp*`` files from ``model_root_dir``.

        This is best effort: directories are skipped, and files that cannot be
        removed are logged instead of aborting the download.
        """
        for filename in os.listdir(model_root_dir):
            if not filename.startswith("tmp"):
                continue
            path = os.path.join(model_root_dir, filename)
            if not os.path.isfile(path):
                continue  # os.remove() cannot delete directories
            try:
                os.remove(path)
            except OSError:
                logging.warning("Could not remove temporary file %s", path, exc_info=True)

    def _emit_partial_sizes(self, directory: str, is_partial) -> None:
        """Emit a progress tuple for each file in *directory* accepted by *is_partial*."""
        try:
            filenames = os.listdir(directory)
        except FileNotFoundError:
            return  # folder does not exist (yet); try again on the next poll
        for filename in filenames:
            if not is_partial(filename):
                continue
            try:
                file_size = os.path.getsize(os.path.join(directory, filename))
            except FileNotFoundError:
                # Renamed or removed between listdir() and getsize(); without this
                # the exception would end the polling thread.
                continue
            self.progress.emit((file_size, self.total_file_size))

    def monitor_file_size(self):
        """Poll partial-file sizes every 2 seconds until stop_event is set."""
        while not self.stop_event.is_set():
            if model_root_dir is not None:
                self._emit_partial_sizes(
                    model_root_dir, lambda name: name.startswith("tmp")
                )
            self._emit_partial_sizes(
                self.incomplete_download_root, lambda name: name.endswith(".incomplete")
            )
            # wait() returns as soon as stop_monitoring() sets the event, so the
            # caller's join() does not block for up to 2 seconds as sleep(2) did.
            self.stop_event.wait(2)

    def start_monitoring(self):
        """Clear old ``tmp*`` files, then start the polling thread."""
        self.clean_tmp_files()
        self.monitor_thread = threading.Thread(target=self.monitor_file_size)
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Report 100% progress, then stop and join the polling thread."""
        self.progress.emit((self.total_file_size, self.total_file_size))
        if self.monitor_thread is not None:
            self.stop_event.set()
            self.monitor_thread.join()
def get_file_size(url, timeout: float = 60):
    """Return the size in bytes of the resource at *url*, using an HTTP HEAD request.

    Redirects are followed, so the size reported is that of the final response.

    Args:
        url: Address of the file.
        timeout: Seconds to wait for the server. Without a timeout, requests
            can block indefinitely on an unresponsive host.

    Returns:
        The ``Content-Length`` header as an int, or 0 when the server omits it.

    Raises:
        requests.exceptions.RequestException: on connection errors, timeouts
            or an HTTP error status.
    """
    response = requests.head(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    return int(response.headers.get("Content-Length", 0))
def download_from_huggingface(
    repo_id: str,
    allow_patterns: List[str],
    progress: pyqtSignal(tuple),
    num_large_files: int = 1
):
    """Download the files matching *allow_patterns* from a Hugging Face repository.

    The first ``num_large_files`` patterns are treated as the large weight
    files. Everything else is downloaded first. The large files are then
    downloaded while a HuggingfaceDownloadMonitor reports progress against
    the largest of their sizes.

    Args:
        repo_id: Hugging Face repository id.
        allow_patterns: File names/patterns to fetch, largest files first.
        progress: Signal receiving ``(current, total)`` tuples.
        num_large_files: How many leading entries of *allow_patterns* are
            large files.

    Returns:
        The local snapshot directory, or "" if a download failed (the error
        is logged).
    """
    progress.emit((0, 100))

    try:
        model_root = huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns[num_large_files:],  # all, but largest
            cache_dir=model_root_dir,
            etag_timeout=60
        )
    except Exception as exc:
        logging.exception(exc)
        return ""

    progress.emit((1, 100))

    # The progress total is the size of the biggest large file that can be resolved.
    largest_file_size = 0
    for pattern in allow_patterns[:num_large_files]:
        try:
            file_url = huggingface_hub.hf_hub_url(repo_id, pattern)
            file_size = get_file_size(file_url)
        except (requests.exceptions.RequestException, KeyError, ValueError):
            # The file is not in this repository (e.g. an HTTP error status), or
            # its size is missing or malformed: skip it rather than aborting.
            logging.debug("Could not determine size of %s in %s", pattern, repo_id)
            continue
        largest_file_size = max(largest_file_size, file_size)

    model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, largest_file_size)

    try:
        model_download_monitor.start_monitoring()
        huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns[:num_large_files],  # largest
            cache_dir=model_root_dir,
            etag_timeout=60
        )
    except Exception as exc:
        logging.exception(exc)
        return ""
    finally:
        # Always stop the polling thread (it is not a daemon thread), even when
        # the download is interrupted by a non-Exception error.
        model_download_monitor.stop_monitoring()

    return model_root
def download_faster_whisper_model(
    model: TranscriptionModel, local_files_only=False, progress: pyqtSignal(tuple) = None
):
    """Fetch, or locate in the local cache, the Faster Whisper files for *model*.

    Args:
        model: The model to fetch. For the custom size,
            ``model.hugging_face_model_id`` names the repository.
        local_files_only: Only look in the local cache (no network). This
            mode passes through exceptions from ``snapshot_download``.
        progress: Signal receiving ``(current, total)`` tuples. It is required
            when downloading, i.e. when ``local_files_only`` is False.

    Returns:
        The local snapshot directory. When downloading, returns "" on failure.

    Raises:
        ValueError: for an unsupported size, or a custom size without a
            repository id.
    """
    size = model.whisper_model_size.to_faster_whisper_model_size()
    custom_repo_id = model.hugging_face_model_id
    # WhisperModelSize is a str enum, so it compares equal to its string value.
    is_custom = size == WhisperModelSize.CUSTOM

    if not is_custom and size not in faster_whisper.utils._MODELS:
        raise ValueError(
            "Invalid model size '%s', expected one of: %s"
            % (size, ", ".join(faster_whisper.utils._MODELS))
        )

    # hugging_face_model_id is Optional: treat None the same as an empty id.
    if is_custom and not custom_repo_id:
        raise ValueError("Custom model id is not provided")

    if is_custom:
        repo_id = custom_repo_id
    elif size == WhisperModelSize.LARGEV3:
        repo_id = "Systran/faster-whisper-large-v3"
    else:
        repo_id = "guillaumekln/faster-whisper-%s" % size

    allow_patterns = [
        "model.bin",  # largest by size first
        "pytorch_model.bin",  # possible alternative model filename
        "config.json",
        "tokenizer.json",
        "vocabulary.txt",
        "vocabulary.json",
    ]

    if local_files_only:
        return huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns,
            local_files_only=True,
            cache_dir=model_root_dir,
            etag_timeout=60
        )

    # The first two patterns are the alternative weight files.
    return download_from_huggingface(
        repo_id,
        allow_patterns=allow_patterns,
        progress=progress,
        num_large_files=2
    )
class ModelDownloader(QRunnable):
    """Download a transcription model as a QRunnable.

    The outcome is reported through ``signals``:
    - ``progress`` receives ``(current, total)`` tuples;
    - ``finished`` receives the local model path;
    - ``error`` receives a user-facing message.
    """

    class Signals(QObject):
        finished = pyqtSignal(str)
        progress = pyqtSignal(tuple)  # (current, total)
        error = pyqtSignal(str)

    def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
        super().__init__()
        self.signals = self.Signals()
        self.model = model
        self.stopped = False  # set by cancel(); checked between download chunks
        self.custom_model_url = custom_model_url

    def run(self) -> None:
        """Download ``self.model`` according to its type and emit the result."""
        logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)

        if self.model.model_type == ModelType.WHISPER_CPP:
            model_name = self.model.whisper_model_size.to_whisper_cpp_model_size()
            if self.custom_model_url:
                url = self.custom_model_url
            else:
                url = huggingface_hub.hf_hub_url(
                    repo_id="ggerganov/whisper.cpp",
                    filename=f"ggml-{model_name}.bin",
                )
            file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
            expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
            return self.download_model_to_path(
                url=url, file_path=file_path, expected_sha256=expected_sha256
            )

        if self.model.model_type == ModelType.WHISPER:
            url = whisper._MODELS[self.model.whisper_model_size.value]
            file_path = get_whisper_file_path(size=self.model.whisper_model_size)
            # The expected checksum is the second-to-last path segment of the URL
            expected_sha256 = url.split("/")[-2]
            return self.download_model_to_path(
                url=url, file_path=file_path, expected_sha256=expected_sha256
            )

        if self.model.model_type == ModelType.FASTER_WHISPER:
            model_path = download_faster_whisper_model(
                model=self.model,
                progress=self.signals.progress,
            )
            if model_path == "":
                self.signals.error.emit(_("Error"))
            self.signals.finished.emit(model_path)
            return

        if self.model.model_type == ModelType.HUGGING_FACE:
            model_path = download_from_huggingface(
                self.model.hugging_face_model_id,
                allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
                progress=self.signals.progress,
                num_large_files=4
            )
            if model_path == "":
                # Report failures the same way as the Faster Whisper branch
                self.signals.error.emit(_("Error"))
            self.signals.finished.emit(model_path)
            return

        if self.model.model_type == ModelType.OPEN_AI_WHISPER_API:
            self.signals.finished.emit("")
            return

        raise Exception("Invalid model type: " + self.model.model_type.value)

    def download_model_to_path(
        self, url: str, file_path: str, expected_sha256: Optional[str]
    ):
        """Download *url* to *file_path* and report the result through signals.

        Emits ``finished`` with *file_path* on success. Emits ``error`` on
        failure. Emits nothing when the download was cancelled.
        """
        try:
            downloaded = self.download_model(url, file_path, expected_sha256)
            if downloaded:
                self.signals.finished.emit(file_path)
        except requests.RequestException:
            self.signals.error.emit(_("A connection error occurred"))
            logging.exception("")
        except Exception as exc:
            self.signals.error.emit(str(exc))
            logging.exception(exc)

    @staticmethod
    def _sha256_of_file(path: str, chunk_size: int = 1024 * 1024) -> str:
        """Return the hex SHA256 of *path*, reading it in chunks.

        Reading in chunks avoids loading multi-gigabyte models into memory.
        """
        digest = hashlib.sha256()
        with open(path, "rb") as file:
            while chunk := file.read(chunk_size):
                digest.update(chunk)
        return digest.hexdigest()

    def download_model(
        self, url: str, file_path: str, expected_sha256: Optional[str]
    ) -> bool:
        """Ensure *file_path* holds the model from *url*.

        An existing file is reused if it matches *expected_sha256* (any
        existing file is reused when no checksum is given). Otherwise the
        model is streamed into a temporary file, verified, and moved into place.

        Returns:
            True when the file is in place, False if the download was cancelled.

        Raises:
            RuntimeError: if *file_path* exists but is not a regular file, or
                the downloaded data fails the checksum.
            requests.RequestException: on network or HTTP errors.
        """
        logging.debug("Downloading model from %s to %s", url, file_path)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        if os.path.exists(file_path) and not os.path.isfile(file_path):
            raise RuntimeError(f"{file_path} exists and is not a regular file")

        if os.path.isfile(file_path):
            if expected_sha256 is None:
                return True
            if self._sha256_of_file(file_path) == expected_sha256:
                return True
            warnings.warn(
                f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

        # mkstemp() creates the file securely, unlike the deprecated,
        # race-prone tempfile.mktemp().
        fd, tmp_file = tempfile.mkstemp()
        logging.debug("Downloading to temporary file = %s", tmp_file)
        try:
            # Hash while writing, so the file does not have to be read back.
            digest = hashlib.sha256()
            # Downloads the model using the requests module instead of urllib to
            # use the certs from certifi when the app is running in frozen mode
            with os.fdopen(fd, "wb") as output, requests.get(
                url, stream=True, timeout=15
            ) as source:
                source.raise_for_status()
                total_size = float(source.headers.get("Content-Length", 0))
                current = 0.0
                self.signals.progress.emit((current, total_size))

                for chunk in source.iter_content(chunk_size=8192):
                    if self.stopped:
                        return False
                    output.write(chunk)
                    digest.update(chunk)
                    current += len(chunk)
                    self.signals.progress.emit((current, total_size))

            if expected_sha256 is not None and digest.hexdigest() != expected_sha256:
                raise RuntimeError(
                    "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
                    "model."
                )

            logging.debug("Downloaded model")

            # https://github.com/chidiwilliams/buzz/issues/454
            shutil.move(tmp_file, file_path)
            logging.debug("Moved file from %s to %s", tmp_file, file_path)
            return True
        finally:
            # After a successful move the temporary file no longer exists. On
            # cancel, checksum mismatch or a network error, remove it.
            if os.path.exists(tmp_file):
                try:
                    os.remove(tmp_file)
                except OSError:
                    logging.warning("Could not remove temporary file %s", tmp_file)

    def cancel(self):
        """Ask a running download to stop at the next chunk."""
        self.stopped = True
def get_custom_api_whisper_model(base_url: str):
    """Return the model name to request from an OpenAI-compatible endpoint.

    Base URLs containing "api.groq.com" get "whisper-large-v3"; any other URL
    gets "whisper-1".
    """
    is_groq = "api.groq.com" in base_url
    return "whisper-large-v3" if is_groq else "whisper-1"