diff --git a/buzz/model_loader.py b/buzz/model_loader.py index e901cbdd..e8e82753 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -2,18 +2,19 @@ import enum import hashlib import logging import os +import time +import threading import shutil import subprocess import sys import tempfile import warnings from dataclasses import dataclass -from typing import Optional +from typing import Optional, List import requests from PyQt6.QtCore import QObject, pyqtSignal, QRunnable from platformdirs import user_cache_dir -from tqdm.auto import tqdm import faster_whisper import whisper @@ -86,11 +87,11 @@ class ModelType(enum.Enum): HUGGING_FACE_MODEL_ALLOW_PATTERNS = [ + "model.safetensors", # largest by size first "added_tokens.json", "config.json", "generation_config.json", "merges.txt", - "model.safetensors", "normalizer.json", "preprocessor_config.json", "special_tokens_map.json", @@ -198,10 +199,6 @@ WHISPER_CPP_MODELS_SHA256 = { } -def get_hugging_face_file_url(author: str, repository_name: str, filename: str): - return f"https://huggingface.co/{author}/{repository_name}/resolve/bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c/{filename}" - - def get_whisper_cpp_file_path(size: WhisperModelSize) -> str: root_dir = user_cache_dir("Buzz") return os.path.join(root_dir, f"ggml-model-whisper-{size.value}.bin") @@ -215,8 +212,88 @@ def get_whisper_file_path(size: WhisperModelSize) -> str: return os.path.join(root_dir, os.path.basename(url)) +class HuggingfaceDownloadMonitor: + def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size: int): + self.model_root = model_root + self.progress = progress + self.total_file_size = total_file_size + self.tmp_download_root = self.get_tmp_download_root(model_root) + self.stop_event = threading.Event() + self.monitor_thread = None + + @staticmethod + def get_tmp_download_root(model_root): + normalized_model_root = os.path.normpath(model_root) + normalized_hub_path = os.path.normpath("huggingface/hub/") + index = normalized_model_root.find(normalized_hub_path) + if index == -1: + raise ValueError(f"Invalid model_root, '{normalized_hub_path}' not found") + return normalized_model_root[:index + len(normalized_hub_path)] + + def clean_tmp_files(self): + for filename in os.listdir(self.tmp_download_root): + if filename.startswith("tmp"): + os.remove(os.path.join(self.tmp_download_root, filename)) + + def monitor_file_size(self): + while not self.stop_event.is_set(): + for filename in os.listdir(self.tmp_download_root): + if filename.startswith("tmp"): + file_size = os.path.getsize(os.path.join(self.tmp_download_root, filename)) + self.progress.emit((file_size, self.total_file_size)) + time.sleep(2) + + def start_monitoring(self): + self.clean_tmp_files() + self.monitor_thread = threading.Thread(target=self.monitor_file_size) + self.monitor_thread.start() + + def stop_monitoring(self): + self.progress.emit((self.total_file_size, self.total_file_size)) + + if self.monitor_thread is not None: + self.stop_event.set() + self.monitor_thread.join() + + +def get_file_size(url): + response = requests.head(url, allow_redirects=True) + response.raise_for_status() + return int(response.headers['Content-Length']) + + +def download_from_huggingface( + repo_id: str, + allow_patterns: List[str], + progress: pyqtSignal(tuple), +): + progress.emit((1, 100)) + + model_root = huggingface_hub.snapshot_download( + repo_id, + allow_patterns=allow_patterns[1:], # all, but largest + ) + + progress.emit((1, 100)) + + largest_file_url = huggingface_hub.hf_hub_url(repo_id, allow_patterns[0]) + total_file_size = get_file_size(largest_file_url) + + model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size) + model_download_monitor.start_monitoring() + + huggingface_hub.snapshot_download( + repo_id, + allow_patterns=allow_patterns[:1], # largest + ) + + model_download_monitor.stop_monitoring() + + return model_root + + def download_faster_whisper_model( - size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None + size: str, local_files_only=False, progress: pyqtSignal(tuple) = None ): if size not in faster_whisper.utils._MODELS: raise ValueError( @@ -227,17 +304,23 @@ def download_faster_whisper_model( repo_id = "guillaumekln/faster-whisper-%s" % size allow_patterns = [ + "model.bin", # largest by size first "config.json", - "model.bin", "tokenizer.json", "vocabulary.txt", ] - return huggingface_hub.snapshot_download( + if local_files_only: + return huggingface_hub.snapshot_download( + repo_id, + allow_patterns=allow_patterns, + local_files_only=True, + ) + + return download_from_huggingface( repo_id, allow_patterns=allow_patterns, - local_files_only=local_files_only, - tqdm_class=tqdm_class, + progress=progress, ) @@ -257,9 +340,8 @@ class ModelDownloader(QRunnable): def run(self) -> None: if self.model.model_type == ModelType.WHISPER_CPP: model_name = self.model.whisper_model_size.value - url = get_hugging_face_file_url( - author="ggerganov", - repository_name="whisper.cpp", + url = huggingface_hub.hf_hub_url( + repo_id="ggerganov/whisper.cpp", filename=f"ggml-{model_name}.bin", ) file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size) @@ -276,31 +358,19 @@ class ModelDownloader(QRunnable): url=url, file_path=file_path, expected_sha256=expected_sha256 ) - progress = self.signals.progress - - # gross abuse of power... - class _tqdm(tqdm): - def update(self, n: float | None = ...) -> bool | None: - progress.emit((n, self.total)) - return super().update(n) - - def close(self): - progress.emit((self.n, self.total)) - return super().close() - if self.model.model_type == ModelType.FASTER_WHISPER: model_path = download_faster_whisper_model( size=self.model.whisper_model_size.to_faster_whisper_model_size(), - tqdm_class=_tqdm, + progress=self.signals.progress, ) self.signals.finished.emit(model_path) return if self.model.model_type == ModelType.HUGGING_FACE: - model_path = huggingface_hub.snapshot_download( + model_path = download_from_huggingface( self.model.hugging_face_model_id, allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS, - tqdm_class=_tqdm + progress=self.signals.progress, ) self.signals.finished.emit(model_path) return diff --git a/tests/model_loader_test.py b/tests/model_loader_test.py new file mode 100644 index 00000000..d1c9d638 --- /dev/null +++ b/tests/model_loader_test.py @@ -0,0 +1,25 @@ +import os +import pytest + +from buzz.model_loader import ModelDownloader,TranscriptionModel, ModelType, WhisperModelSize + + +class TestModelLoader: + @pytest.mark.parametrize( + "model", + [ + TranscriptionModel( + model_type=ModelType.HUGGING_FACE, + hugging_face_model_id="RaivisDejus/whisper-tiny-lv", + ), + ], + ) + def test_download_model(self, model: TranscriptionModel): + model_loader = ModelDownloader(model=model) + model_loader.run() + + model_path = model.get_local_model_path() + + assert model_path is not None, "Model path is None" + assert os.path.isdir(model_path), "Model path is not a directory" + assert len(os.listdir(model_path)) > 0, "Model directory is empty"