mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 14:45:46 +01:00
Refactored model downloading (#761)
This commit is contained in:
parent
7820952616
commit
731efd7d38
2 changed files with 125 additions and 30 deletions
|
|
@ -2,18 +2,19 @@ import enum
|
|||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import requests
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
|
||||
from platformdirs import user_cache_dir
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import faster_whisper
|
||||
import whisper
|
||||
|
|
@ -86,11 +87,11 @@ class ModelType(enum.Enum):
|
|||
|
||||
|
||||
HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
|
||||
"model.safetensors", # largest by size first
|
||||
"added_tokens.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
"merges.txt",
|
||||
"model.safetensors",
|
||||
"normalizer.json",
|
||||
"preprocessor_config.json",
|
||||
"special_tokens_map.json",
|
||||
|
|
@ -198,10 +199,6 @@ WHISPER_CPP_MODELS_SHA256 = {
|
|||
}
|
||||
|
||||
|
||||
def get_hugging_face_file_url(
    author: str,
    repository_name: str,
    filename: str,
    revision: str = "bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c",
) -> str:
    """Build the ``resolve`` URL of a single file in a Hugging Face repository.

    Args:
        author: Owner (user or organisation) of the repository.
        repository_name: Name of the repository under ``author``.
        filename: Path of the file inside the repository.
        revision: Git revision to resolve the file at. Defaults to the commit
            hash that used to be hard-coded in the URL, so existing callers
            get exactly the same URL as before.

    Returns:
        The ``https://huggingface.co/...`` URL for the requested file.
    """
    return (
        f"https://huggingface.co/{author}/{repository_name}"
        f"/resolve/{revision}/{filename}"
    )
|
||||
|
||||
|
||||
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
    """Return where the whisper.cpp model file for ``size`` is stored.

    The file lives directly inside the "Buzz" user cache directory and is
    named after the value of the model size.
    """
    cache_root = user_cache_dir("Buzz")
    model_filename = f"ggml-model-whisper-{size.value}.bin"
    return os.path.join(cache_root, model_filename)
|
||||
|
|
@ -215,8 +212,88 @@ def get_whisper_file_path(size: WhisperModelSize) -> str:
|
|||
return os.path.join(root_dir, os.path.basename(url))
|
||||
|
||||
|
||||
class HuggingfaceDownloadMonitor:
    """Report download progress by polling the size of temporary files.

    A background thread periodically lists the directory returned by
    :meth:`get_tmp_download_root` and, for every regular file whose name
    starts with ``tmp``, emits ``(current_size, total_file_size)`` on the
    ``progress`` signal.

    NOTE(review): this presumes huggingface_hub writes in-flight downloads as
    ``tmp*`` files directly inside the ``huggingface/hub`` cache directory --
    confirm against the installed huggingface_hub version.
    """

    # Seconds the polling thread waits between two scans of the directory.
    POLL_INTERVAL_SECONDS = 2

    def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size: int):
        """Create a monitor; no thread is started until ``start_monitoring``.

        Args:
            model_root: Path inside the ``huggingface/hub`` cache directory.
            progress: Object with an ``emit`` method that receives
                ``(downloaded_bytes, total_bytes)`` tuples.
            total_file_size: Expected final size in bytes of the download.

        Raises:
            ValueError: If ``model_root`` does not contain ``huggingface/hub``.
        """
        self.model_root = model_root
        self.progress = progress
        self.total_file_size = total_file_size
        self.tmp_download_root = self.get_tmp_download_root(model_root)
        self.stop_event = threading.Event()
        self.monitor_thread = None

    @staticmethod
    def get_tmp_download_root(model_root):
        """Return the ``huggingface/hub`` prefix of ``model_root``.

        Raises:
            ValueError: If the normalized path has no ``huggingface/hub`` part.
        """
        normalized_model_root = os.path.normpath(model_root)
        normalized_hub_path = os.path.normpath("huggingface/hub/")
        index = normalized_model_root.find(normalized_hub_path)
        if index == -1:
            raise ValueError(f"Invalid model_root, '{normalized_hub_path}' not found")
        return normalized_model_root[:index + len(normalized_hub_path)]

    def _tmp_file_paths(self):
        """List regular ``tmp*`` files currently in the download directory."""
        root = self.tmp_download_root
        candidates = (
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.startswith("tmp")
        )
        # Directories are skipped: they can neither be removed with
        # os.remove nor do they describe the size of a download.
        return [path for path in candidates if os.path.isfile(path)]

    def clean_tmp_files(self):
        """Delete leftover ``tmp*`` files so they do not skew the progress."""
        for path in self._tmp_file_paths():
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone (removed by someone else) -- nothing to clean.
                continue

    def monitor_file_size(self):
        """Poll temp file sizes and emit progress until ``stop_event`` is set."""
        while not self.stop_event.is_set():
            try:
                paths = self._tmp_file_paths()
            except OSError:
                # A failed scan must not kill the polling thread; retry on
                # the next tick instead.
                logging.debug("Could not list %s", self.tmp_download_root, exc_info=True)
                paths = []

            for path in paths:
                try:
                    file_size = os.path.getsize(path)
                except OSError:
                    # The file can vanish between listing and stat
                    # (presumably when the download finishes and the file is
                    # moved to its final place).
                    continue
                self.progress.emit((file_size, self.total_file_size))

            # Event.wait instead of time.sleep so that stop_monitoring() does
            # not have to wait for a full poll interval.
            self.stop_event.wait(self.POLL_INTERVAL_SECONDS)

    def start_monitoring(self):
        """Remove stale temp files and start the polling thread."""
        self.clean_tmp_files()
        # Allow the monitor to be started again after a previous stop.
        self.stop_event.clear()
        self.monitor_thread = threading.Thread(target=self.monitor_file_size)
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Stop the polling thread, then report the download as complete."""
        if self.monitor_thread is not None:
            self.stop_event.set()
            self.monitor_thread.join()
            self.monitor_thread = None

        # Emitted only after the thread has finished so that a late partial
        # value from the poller can never follow the final 100% report.
        self.progress.emit((self.total_file_size, self.total_file_size))
|
||||
|
||||
|
||||
def get_file_size(url, timeout: float = 30):
    """Return the size in bytes of the resource at ``url``.

    Sends a ``HEAD`` request (following redirects) and reads the
    ``Content-Length`` header of the final response.

    Args:
        url: Address of the file to measure.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would block the
            calling download forever on an unresponsive server.

    Returns:
        The value of the ``Content-Length`` header as an ``int``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
        KeyError: If the response carries no ``Content-Length`` header.
    """
    response = requests.head(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    return int(response.headers["Content-Length"])
|
||||
|
||||
|
||||
def download_from_huggingface(
    repo_id: str,
    allow_patterns: List[str],
    progress: pyqtSignal(tuple),
):
    """Download a Hugging Face snapshot while reporting progress.

    The files are fetched in two steps: first everything except
    ``allow_patterns[0]``, then ``allow_patterns[0]`` on its own while a
    ``HuggingfaceDownloadMonitor`` reports its growing size on ``progress``.

    Args:
        repo_id: Hugging Face repository to download from.
        allow_patterns: File patterns to download. The first entry must be
            the largest file, because only that download is monitored.
        progress: Object with an ``emit`` method that receives
            ``(downloaded, total)`` tuples.

    Returns:
        The snapshot directory returned by ``snapshot_download``.

    Raises:
        ValueError: If ``allow_patterns`` is empty.
    """
    if not allow_patterns:
        # Fail before any network work instead of with an IndexError later.
        raise ValueError("allow_patterns must contain at least one pattern")

    largest_pattern, smaller_patterns = allow_patterns[0], allow_patterns[1:]

    # Show a token amount of progress while the small files are fetched.
    progress.emit((1, 100))

    model_root = huggingface_hub.snapshot_download(
        repo_id,
        allow_patterns=smaller_patterns,  # all, but largest
    )

    progress.emit((1, 100))

    largest_file_url = huggingface_hub.hf_hub_url(repo_id, largest_pattern)
    total_file_size = get_file_size(largest_file_url)

    model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size)
    model_download_monitor.start_monitoring()

    try:
        huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=[largest_pattern],  # largest
        )
    except BaseException:
        # Stop the polling thread without reporting completion; otherwise it
        # would keep running (and keep the process alive) after a failure.
        model_download_monitor.stop_event.set()
        if model_download_monitor.monitor_thread is not None:
            model_download_monitor.monitor_thread.join()
        raise

    model_download_monitor.stop_monitoring()

    return model_root
|
||||
|
||||
|
||||
def download_faster_whisper_model(
|
||||
size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None
|
||||
size: str, local_files_only=False, progress: pyqtSignal(tuple) = None
|
||||
):
|
||||
if size not in faster_whisper.utils._MODELS:
|
||||
raise ValueError(
|
||||
|
|
@ -227,17 +304,23 @@ def download_faster_whisper_model(
|
|||
repo_id = "guillaumekln/faster-whisper-%s" % size
|
||||
|
||||
allow_patterns = [
|
||||
"model.bin", # largest by size first
|
||||
"config.json",
|
||||
"model.bin",
|
||||
"tokenizer.json",
|
||||
"vocabulary.txt",
|
||||
]
|
||||
|
||||
return huggingface_hub.snapshot_download(
|
||||
if local_files_only:
|
||||
return huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
return download_from_huggingface(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns,
|
||||
local_files_only=local_files_only,
|
||||
tqdm_class=tqdm_class,
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -257,9 +340,8 @@ class ModelDownloader(QRunnable):
|
|||
def run(self) -> None:
|
||||
if self.model.model_type == ModelType.WHISPER_CPP:
|
||||
model_name = self.model.whisper_model_size.value
|
||||
url = get_hugging_face_file_url(
|
||||
author="ggerganov",
|
||||
repository_name="whisper.cpp",
|
||||
url = huggingface_hub.hf_hub_url(
|
||||
repo_id="ggerganov/whisper.cpp",
|
||||
filename=f"ggml-{model_name}.bin",
|
||||
)
|
||||
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
|
||||
|
|
@ -276,31 +358,19 @@ class ModelDownloader(QRunnable):
|
|||
url=url, file_path=file_path, expected_sha256=expected_sha256
|
||||
)
|
||||
|
||||
progress = self.signals.progress
|
||||
|
||||
# gross abuse of power...
|
||||
class _tqdm(tqdm):
|
||||
def update(self, n: float | None = ...) -> bool | None:
|
||||
progress.emit((n, self.total))
|
||||
return super().update(n)
|
||||
|
||||
def close(self):
|
||||
progress.emit((self.n, self.total))
|
||||
return super().close()
|
||||
|
||||
if self.model.model_type == ModelType.FASTER_WHISPER:
|
||||
model_path = download_faster_whisper_model(
|
||||
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
|
||||
tqdm_class=_tqdm,
|
||||
progress=self.signals.progress,
|
||||
)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
if self.model.model_type == ModelType.HUGGING_FACE:
|
||||
model_path = huggingface_hub.snapshot_download(
|
||||
model_path = download_from_huggingface(
|
||||
self.model.hugging_face_model_id,
|
||||
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
|
||||
tqdm_class=_tqdm
|
||||
progress=self.signals.progress,
|
||||
)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
|
|
|||
25
tests/model_loader_test.py
Normal file
25
tests/model_loader_test.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import os
|
||||
import pytest
|
||||
|
||||
from buzz.model_loader import ModelDownloader,TranscriptionModel, ModelType, WhisperModelSize
|
||||
|
||||
|
||||
class TestModelLoader:
    """Download tests for ``ModelDownloader``."""

    @pytest.mark.parametrize(
        "model",
        [
            TranscriptionModel(
                model_type=ModelType.HUGGING_FACE,
                hugging_face_model_id="RaivisDejus/whisper-tiny-lv",
            ),
        ],
    )
    def test_download_model(self, model: TranscriptionModel):
        """Run the downloader in the calling thread and inspect the result."""
        downloader = ModelDownloader(model=model)
        downloader.run()

        local_path = model.get_local_model_path()

        assert local_path is not None, "Model path is None"
        assert os.path.isdir(local_path), "Model path is not a directory"

        downloaded_entries = os.listdir(local_path)
        assert len(downloaded_entries) > 0, "Model directory is empty"
|
||||
Loading…
Add table
Add a link
Reference in a new issue