mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 22:55:46 +01:00
862 fix for multi part model download (#865)
This commit is contained in:
parent
6c3959d0ff
commit
ebb7cde23a
3 changed files with 21 additions and 10 deletions
|
|
@ -100,6 +100,10 @@ class ModelType(enum.Enum):
|
|||
|
||||
HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
|
||||
"model.safetensors", # largest by size first
|
||||
"pytorch_model.bin",
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
"model.safetensors.index.json",
|
||||
"added_tokens.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
|
|
@ -218,7 +222,8 @@ class TranscriptionModel:
|
|||
self.hugging_face_model_id,
|
||||
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
|
||||
local_files_only=True,
|
||||
cache_dir=model_root_dir
|
||||
cache_dir=model_root_dir,
|
||||
etag_timeout=60
|
||||
)
|
||||
except (ValueError, FileNotFoundError):
|
||||
return None
|
||||
|
|
@ -324,7 +329,8 @@ def download_from_huggingface(
|
|||
model_root = huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns[num_large_files:], # all, but largest
|
||||
cache_dir=model_root_dir
|
||||
cache_dir=model_root_dir,
|
||||
etag_timeout=60
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
|
|
@ -351,7 +357,8 @@ def download_from_huggingface(
|
|||
huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns[:num_large_files], # largest
|
||||
cache_dir=model_root_dir
|
||||
cache_dir=model_root_dir,
|
||||
etag_timeout=60
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
|
|
@ -399,7 +406,8 @@ def download_faster_whisper_model(
|
|||
repo_id,
|
||||
allow_patterns=allow_patterns,
|
||||
local_files_only=True,
|
||||
cache_dir=model_root_dir
|
||||
cache_dir=model_root_dir,
|
||||
etag_timeout=60
|
||||
)
|
||||
|
||||
return download_from_huggingface(
|
||||
|
|
@ -469,6 +477,7 @@ class ModelDownloader(QRunnable):
|
|||
self.model.hugging_face_model_id,
|
||||
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
|
||||
progress=self.signals.progress,
|
||||
num_large_files=4
|
||||
)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
|
||||
|
||||
|
|
@ -17,12 +17,14 @@ class TransformersWhisper:
|
|||
language: str,
|
||||
task: str,
|
||||
):
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
safetensors_path = os.path.join(self.model_id, "model.safetensors")
|
||||
use_safetensors = os.path.exists(safetensors_path)
|
||||
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.model_id, torch_dtype=torch_dtype, use_safetensors=True
|
||||
self.model_id, torch_dtype=torch_dtype, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
model.generation_config.language = language
|
||||
|
|
|
|||
|
|
@ -35,11 +35,11 @@ class ModelDownloadProgressDialog(QProgressDialog):
|
|||
def update_label_text(self, fraction_completed: float):
|
||||
downloading_text = _("Downloading model")
|
||||
remaining_text = _("remaining")
|
||||
label_text = f"{downloading_text} ({fraction_completed:.0%}"
|
||||
label_text = f"{downloading_text} ("
|
||||
if fraction_completed > 0:
|
||||
time_spent = (datetime.now() - self.start_time).total_seconds()
|
||||
time_left = (time_spent / fraction_completed) - time_spent
|
||||
label_text += f", {humanize.naturaldelta(time_left)} {remaining_text}"
|
||||
label_text += f"{humanize.naturaldelta(time_left)} {remaining_text}"
|
||||
label_text += ")"
|
||||
|
||||
self.setLabelText(label_text)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue