Refactored model downloading (#761)

This commit is contained in:
Raivis Dejus 2024-05-28 09:07:00 +03:00 committed by GitHub
commit 731efd7d38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 125 additions and 30 deletions

View file

@ -2,18 +2,19 @@ import enum
import hashlib
import logging
import os
import time
import threading
import shutil
import subprocess
import sys
import tempfile
import warnings
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List
import requests
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
from platformdirs import user_cache_dir
from tqdm.auto import tqdm
import faster_whisper
import whisper
@ -86,11 +87,11 @@ class ModelType(enum.Enum):
HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
"model.safetensors", # largest by size first
"added_tokens.json",
"config.json",
"generation_config.json",
"merges.txt",
"model.safetensors",
"normalizer.json",
"preprocessor_config.json",
"special_tokens_map.json",
@ -198,10 +199,6 @@ WHISPER_CPP_MODELS_SHA256 = {
}
def get_hugging_face_file_url(
    author: str,
    repository_name: str,
    filename: str,
    revision: str = "bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c",
) -> str:
    """Build the direct download URL of a file in a Hugging Face repository.

    Args:
        author: Owner of the repository (user or organisation name).
        repository_name: Name of the repository under that owner.
        filename: Path of the file inside the repository.
        revision: Git revision placed in the ``resolve/<revision>/`` URL
            segment. Defaults to the commit hash this function has always
            used, so existing callers get the same URL as before.

    Returns:
        The ``https://huggingface.co/...`` URL for the requested file.
    """
    return (
        f"https://huggingface.co/{author}/{repository_name}"
        f"/resolve/{revision}/{filename}"
    )
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
    """Return the path of the whisper.cpp model file for ``size``.

    The file is named ``ggml-model-whisper-<size.value>.bin`` and is located
    directly inside the "Buzz" user cache directory.
    """
    model_filename = f"ggml-model-whisper-{size.value}.bin"
    return os.path.join(user_cache_dir("Buzz"), model_filename)
@ -215,8 +212,88 @@ def get_whisper_file_path(size: WhisperModelSize) -> str:
return os.path.join(root_dir, os.path.basename(url))
class HuggingfaceDownloadMonitor:
    """Report download progress by polling the size of temporary files.

    A background thread lists the ``huggingface/hub`` directory that contains
    ``model_root`` every two seconds and emits ``(current_size, total_size)``
    on ``progress`` for every entry whose name starts with ``tmp``
    (presumably the partially downloaded file written by huggingface_hub --
    verify against the library version in use).
    """

    # The annotation is quoted on purpose: the unquoted ``pyqtSignal(tuple)``
    # is a call expression that would be evaluated when the class is defined.
    def __init__(self, model_root: str, progress: "pyqtSignal", total_file_size: int):
        """Store the monitor configuration; no thread is started here.

        Args:
            model_root: Path of the downloaded model; must contain the
                ``huggingface/hub`` path component.
            progress: Object with an ``emit(tuple)`` method that receives
                ``(current_size, total_size)`` updates.
            total_file_size: Expected final size, reported as the second
                element of every progress tuple.

        Raises:
            ValueError: If ``model_root`` is not inside ``huggingface/hub``.
        """
        self.model_root = model_root
        self.progress = progress
        self.total_file_size = total_file_size
        self.tmp_download_root = self.get_tmp_download_root(model_root)
        self.stop_event = threading.Event()
        self.monitor_thread = None

    @staticmethod
    def get_tmp_download_root(model_root):
        """Return the prefix of ``model_root`` up to and including ``huggingface/hub``.

        Raises:
            ValueError: If the normalized path has no ``huggingface/hub`` part.
        """
        normalized_model_root = os.path.normpath(model_root)
        normalized_hub_path = os.path.normpath("huggingface/hub/")
        index = normalized_model_root.find(normalized_hub_path)
        if index == -1:
            raise ValueError(f"Invalid model_root, '{normalized_hub_path}' not found")
        return normalized_model_root[:index + len(normalized_hub_path)]

    def clean_tmp_files(self):
        """Delete leftover ``tmp*`` entries so stale files are not reported."""
        for filename in os.listdir(self.tmp_download_root):
            if not filename.startswith("tmp"):
                continue
            tmp_path = os.path.join(self.tmp_download_root, filename)
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: an entry that vanished, is locked or is
                # a directory must not abort the download that follows.
                logging.debug(
                    "Could not remove temporary download file %s",
                    tmp_path,
                    exc_info=True,
                )

    def monitor_file_size(self):
        """Poll ``tmp*`` file sizes and emit progress until stopped."""
        while not self.stop_event.is_set():
            for filename in os.listdir(self.tmp_download_root):
                if not filename.startswith("tmp"):
                    continue
                try:
                    file_size = os.path.getsize(
                        os.path.join(self.tmp_download_root, filename)
                    )
                except OSError:
                    # The file can disappear between listdir() and getsize()
                    # (e.g. when the download completes); skip it instead of
                    # letting the exception kill the monitor thread.
                    continue
                self.progress.emit((file_size, self.total_file_size))

            # Event.wait() returns as soon as stop_event is set, so
            # stop_monitoring() does not have to sit out a full sleep.
            self.stop_event.wait(2)

    def start_monitoring(self):
        """Remove stale temporary files and start the polling thread."""
        self.clean_tmp_files()
        # Daemon thread: a monitor that is never stopped must not keep the
        # interpreter alive on exit.
        self.monitor_thread = threading.Thread(
            target=self.monitor_file_size, daemon=True
        )
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Stop the polling thread, then report the download as complete."""
        if self.monitor_thread is not None:
            self.stop_event.set()
            self.monitor_thread.join()

        # Emitted only after the thread has finished, so no stale intermediate
        # value can arrive after the final 100% update.
        self.progress.emit((self.total_file_size, self.total_file_size))
def get_file_size(url, timeout: float = 30):
    """Return the size in bytes of the resource at ``url``.

    Sends a HEAD request (following redirects) and reads the
    ``Content-Length`` header of the final response.

    Args:
        url: Address of the file to inspect.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on an unresponsive
            server, which would stall the whole download.

    Returns:
        The value of the ``Content-Length`` header as an ``int``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
        KeyError: If the response carries no ``Content-Length`` header.
    """
    response = requests.head(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    return int(response.headers['Content-Length'])
def download_from_huggingface(
    repo_id: str,
    allow_patterns: List[str],
    progress: "pyqtSignal",
):
    """Download a model snapshot from Hugging Face while reporting progress.

    All files except ``allow_patterns[0]`` are fetched first. The first
    pattern is then fetched on its own while a HuggingfaceDownloadMonitor
    reports its growing size on ``progress``.

    Args:
        repo_id: Hugging Face repository id.
        allow_patterns: File patterns to download; the first entry must be
            the largest file, as it is the one whose progress is tracked.
        progress: Object with an ``emit(tuple)`` method receiving
            ``(current, total)`` updates. ``None`` disables reporting.

    Returns:
        The local snapshot directory returned by
        ``huggingface_hub.snapshot_download``.

    Raises:
        ValueError: If ``allow_patterns`` is empty.
    """
    if not allow_patterns:
        # Fail before any network traffic instead of with an IndexError
        # after the first snapshot_download() call.
        raise ValueError("allow_patterns must contain at least one pattern")

    if progress is None:
        # Callers may forward an optional signal; without one there is
        # nothing to report to, so perform a plain download.
        return huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns,
        )

    progress.emit((1, 100))

    model_root = huggingface_hub.snapshot_download(
        repo_id,
        allow_patterns=allow_patterns[1:],  # all, but largest
    )

    progress.emit((1, 100))

    largest_file_url = huggingface_hub.hf_hub_url(repo_id, allow_patterns[0])
    total_file_size = get_file_size(largest_file_url)

    model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size)
    model_download_monitor.start_monitoring()

    try:
        huggingface_hub.snapshot_download(
            repo_id,
            allow_patterns=allow_patterns[:1],  # largest
        )
    except BaseException:
        # Stop the polling thread so it does not outlive a failed download,
        # but do not go through stop_monitoring(): that would report the
        # download as complete.
        model_download_monitor.stop_event.set()
        if model_download_monitor.monitor_thread is not None:
            model_download_monitor.monitor_thread.join()
        raise

    model_download_monitor.stop_monitoring()

    return model_root
def download_faster_whisper_model(
size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None
size: str, local_files_only=False, progress: pyqtSignal(tuple) = None
):
if size not in faster_whisper.utils._MODELS:
raise ValueError(
@ -227,17 +304,23 @@ def download_faster_whisper_model(
repo_id = "guillaumekln/faster-whisper-%s" % size
allow_patterns = [
"model.bin", # largest by size first
"config.json",
"model.bin",
"tokenizer.json",
"vocabulary.txt",
]
return huggingface_hub.snapshot_download(
if local_files_only:
return huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns,
local_files_only=True,
)
return download_from_huggingface(
repo_id,
allow_patterns=allow_patterns,
local_files_only=local_files_only,
tqdm_class=tqdm_class,
progress=progress,
)
@ -257,9 +340,8 @@ class ModelDownloader(QRunnable):
def run(self) -> None:
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_file_url(
author="ggerganov",
repository_name="whisper.cpp",
url = huggingface_hub.hf_hub_url(
repo_id="ggerganov/whisper.cpp",
filename=f"ggml-{model_name}.bin",
)
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
@ -276,31 +358,19 @@ class ModelDownloader(QRunnable):
url=url, file_path=file_path, expected_sha256=expected_sha256
)
progress = self.signals.progress
# gross abuse of power...
class _tqdm(tqdm):
def update(self, n: float | None = ...) -> bool | None:
progress.emit((n, self.total))
return super().update(n)
def close(self):
progress.emit((self.n, self.total))
return super().close()
if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
tqdm_class=_tqdm,
progress=self.signals.progress,
)
self.signals.finished.emit(model_path)
return
if self.model.model_type == ModelType.HUGGING_FACE:
model_path = huggingface_hub.snapshot_download(
model_path = download_from_huggingface(
self.model.hugging_face_model_id,
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
tqdm_class=_tqdm
progress=self.signals.progress,
)
self.signals.finished.emit(model_path)
return

View file

@ -0,0 +1,25 @@
import os
import pytest
from buzz.model_loader import ModelDownloader,TranscriptionModel, ModelType, WhisperModelSize
class TestModelLoader:
    """Integration test: download a model through ModelDownloader."""

    @pytest.mark.parametrize(
        "model",
        [
            TranscriptionModel(
                model_type=ModelType.HUGGING_FACE,
                hugging_face_model_id="RaivisDejus/whisper-tiny-lv",
            ),
        ],
    )
    def test_download_model(self, model: TranscriptionModel):
        """The downloaded model must end up in a non-empty local directory."""
        # run() is invoked directly, so the download executes in this thread.
        ModelDownloader(model=model).run()

        local_path = model.get_local_model_path()

        assert local_path is not None, "Model path is None"
        assert os.path.isdir(local_path), "Model path is not a directory"
        downloaded_entries = os.listdir(local_path)
        assert len(downloaded_entries) > 0, "Model directory is empty"