diff --git a/buzz/model_loader.py b/buzz/model_loader.py
index 08d15730..6a9aed69 100644
--- a/buzz/model_loader.py
+++ b/buzz/model_loader.py
@@ -100,6 +100,10 @@ class ModelType(enum.Enum):
 
 HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
     "model.safetensors",  # largest by size first
+    "pytorch_model.bin",
+    "model-00001-of-00002.safetensors",
+    "model-00002-of-00002.safetensors",
+    "model.safetensors.index.json",
     "added_tokens.json",
     "config.json",
     "generation_config.json",
@@ -218,7 +222,8 @@ class TranscriptionModel:
                 self.hugging_face_model_id,
                 allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
                 local_files_only=True,
-                cache_dir=model_root_dir
+                cache_dir=model_root_dir,
+                etag_timeout=60
             )
         except (ValueError, FileNotFoundError):
             return None
@@ -324,7 +329,8 @@ def download_from_huggingface(
         model_root = huggingface_hub.snapshot_download(
             repo_id,
             allow_patterns=allow_patterns[num_large_files:],  # all, but largest
-            cache_dir=model_root_dir
+            cache_dir=model_root_dir,
+            etag_timeout=60
         )
     except Exception as exc:
         logging.exception(exc)
@@ -351,7 +357,8 @@ def download_from_huggingface(
         huggingface_hub.snapshot_download(
             repo_id,
             allow_patterns=allow_patterns[:num_large_files],  # largest
-            cache_dir=model_root_dir
+            cache_dir=model_root_dir,
+            etag_timeout=60
         )
     except Exception as exc:
         logging.exception(exc)
@@ -399,7 +406,8 @@ def download_faster_whisper_model(
             repo_id,
             allow_patterns=allow_patterns,
             local_files_only=True,
-            cache_dir=model_root_dir
+            cache_dir=model_root_dir,
+            etag_timeout=60
         )
 
     return download_from_huggingface(
@@ -469,6 +477,7 @@ class ModelDownloader(QRunnable):
                 self.model.hugging_face_model_id,
                 allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
                 progress=self.signals.progress,
+                num_large_files=4
             )
             self.signals.finished.emit(model_path)
             return
diff --git a/buzz/transformers_whisper.py b/buzz/transformers_whisper.py
index 270c8440..80d1fe12 100644
--- a/buzz/transformers_whisper.py
+++ b/buzz/transformers_whisper.py
@@ -1,7 +1,7 @@
-from typing import Optional, Union
-
+import os
 import numpy as np
 import torch
+from typing import Optional, Union
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 
 
@@ -17,12 +17,14 @@ class TransformersWhisper:
         language: str,
         task: str,
     ):
-
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
+        safetensors_path = os.path.join(self.model_id, "model.safetensors")
+        use_safetensors = os.path.exists(safetensors_path)
+
         model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            self.model_id, torch_dtype=torch_dtype, use_safetensors=True
+            self.model_id, torch_dtype=torch_dtype, use_safetensors=use_safetensors
         )
 
         model.generation_config.language = language
diff --git a/buzz/widgets/model_download_progress_dialog.py b/buzz/widgets/model_download_progress_dialog.py
index ca5b9a8b..e2a41dce 100644
--- a/buzz/widgets/model_download_progress_dialog.py
+++ b/buzz/widgets/model_download_progress_dialog.py
@@ -35,11 +35,11 @@ class ModelDownloadProgressDialog(QProgressDialog):
     def update_label_text(self, fraction_completed: float):
         downloading_text = _("Downloading model")
         remaining_text = _("remaining")
-        label_text = f"{downloading_text} ({fraction_completed:.0%}"
+        label_text = f"{downloading_text} ("
         if fraction_completed > 0:
             time_spent = (datetime.now() - self.start_time).total_seconds()
             time_left = (time_spent / fraction_completed) - time_spent
-            label_text += f", {humanize.naturaldelta(time_left)} {remaining_text}"
+            label_text += f"{humanize.naturaldelta(time_left)} {remaining_text}"
         label_text += ")"
 
         self.setLabelText(label_text)