862 fix for multi part model download (#865)

This commit is contained in:
Raivis Dejus 2024-08-02 09:54:13 +03:00 committed by GitHub
commit ebb7cde23a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 21 additions and 10 deletions

View file

@ -100,6 +100,10 @@ class ModelType(enum.Enum):
HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
"model.safetensors", # largest by size first
"pytorch_model.bin",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
"added_tokens.json",
"config.json",
"generation_config.json",
@ -218,7 +222,8 @@ class TranscriptionModel:
self.hugging_face_model_id,
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
local_files_only=True,
cache_dir=model_root_dir
cache_dir=model_root_dir,
etag_timeout=60
)
except (ValueError, FileNotFoundError):
return None
@ -324,7 +329,8 @@ def download_from_huggingface(
model_root = huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[num_large_files:], # all, but largest
cache_dir=model_root_dir
cache_dir=model_root_dir,
etag_timeout=60
)
except Exception as exc:
logging.exception(exc)
@ -351,7 +357,8 @@ def download_from_huggingface(
huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[:num_large_files], # largest
cache_dir=model_root_dir
cache_dir=model_root_dir,
etag_timeout=60
)
except Exception as exc:
logging.exception(exc)
@ -399,7 +406,8 @@ def download_faster_whisper_model(
repo_id,
allow_patterns=allow_patterns,
local_files_only=True,
cache_dir=model_root_dir
cache_dir=model_root_dir,
etag_timeout=60
)
return download_from_huggingface(
@ -469,6 +477,7 @@ class ModelDownloader(QRunnable):
self.model.hugging_face_model_id,
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
progress=self.signals.progress,
num_large_files=4
)
self.signals.finished.emit(model_path)
return

View file

@ -1,7 +1,7 @@
from typing import Optional, Union
import os
import numpy as np
import torch
from typing import Optional, Union
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
@ -17,12 +17,14 @@ class TransformersWhisper:
language: str,
task: str,
):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
safetensors_path = os.path.join(self.model_id, "model.safetensors")
use_safetensors = os.path.exists(safetensors_path)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
self.model_id, torch_dtype=torch_dtype, use_safetensors=True
self.model_id, torch_dtype=torch_dtype, use_safetensors=use_safetensors
)
model.generation_config.language = language

View file

@ -35,11 +35,11 @@ class ModelDownloadProgressDialog(QProgressDialog):
def update_label_text(self, fraction_completed: float):
downloading_text = _("Downloading model")
remaining_text = _("remaining")
label_text = f"{downloading_text} ({fraction_completed:.0%}"
label_text = f"{downloading_text} ("
if fraction_completed > 0:
time_spent = (datetime.now() - self.start_time).total_seconds()
time_left = (time_spent / fraction_completed) - time_spent
label_text += f", {humanize.naturaldelta(time_left)} {remaining_text}"
label_text += f"{humanize.naturaldelta(time_left)} {remaining_text}"
label_text += ")"
self.setLabelText(label_text)