mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 14:45:46 +01:00
Adding option to specify custom model root (#894)
This commit is contained in:
parent
f6fc65eeae
commit
4d9547d9c1
4 changed files with 10 additions and 11 deletions
|
|
@ -36,6 +36,7 @@ except ImportError:
|
|||
|
||||
model_root_dir = user_cache_dir("Buzz")
|
||||
model_root_dir = os.path.join(model_root_dir, "models")
|
||||
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
|
||||
os.makedirs(model_root_dir, exist_ok=True)
|
||||
|
||||
logging.debug("Model root directory: %s", model_root_dir)
|
||||
|
|
@ -270,7 +271,6 @@ class HuggingfaceDownloadMonitor:
|
|||
self.model_root = model_root
|
||||
self.progress = progress
|
||||
self.total_file_size = total_file_size
|
||||
self.tmp_download_root = None
|
||||
self.incomplete_download_root = None
|
||||
self.stop_event = threading.Event()
|
||||
self.monitor_thread = None
|
||||
|
|
@ -278,25 +278,20 @@ class HuggingfaceDownloadMonitor:
|
|||
|
||||
def set_download_roots(self):
|
||||
normalized_model_root = os.path.normpath(self.model_root)
|
||||
normalized_hub_path = os.path.normpath("/models/")
|
||||
index = normalized_model_root.find(normalized_hub_path)
|
||||
if index > 0:
|
||||
self.tmp_download_root = normalized_model_root[:index + len(normalized_hub_path)]
|
||||
|
||||
two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
|
||||
self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))
|
||||
|
||||
def clean_tmp_files(self):
|
||||
for filename in os.listdir(self.tmp_download_root):
|
||||
for filename in os.listdir(model_root_dir):
|
||||
if filename.startswith("tmp"):
|
||||
os.remove(os.path.join(self.tmp_download_root, filename))
|
||||
os.remove(os.path.join(model_root_dir, filename))
|
||||
|
||||
def monitor_file_size(self):
|
||||
while not self.stop_event.is_set():
|
||||
if self.tmp_download_root is not None:
|
||||
for filename in os.listdir(self.tmp_download_root):
|
||||
if model_root_dir is not None:
|
||||
for filename in os.listdir(model_root_dir):
|
||||
if filename.startswith("tmp"):
|
||||
file_size = os.path.getsize(os.path.join(self.tmp_download_root, filename))
|
||||
file_size = os.path.getsize(os.path.join(model_root_dir, filename))
|
||||
self.progress.emit((file_size, self.total_file_size))
|
||||
|
||||
for filename in os.listdir(self.incomplete_download_root):
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class RecordingTranscriber(QObject):
|
|||
elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
|
||||
model_root_dir = user_cache_dir("Buzz")
|
||||
model_root_dir = os.path.join(model_root_dir, "models")
|
||||
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
|
||||
|
||||
device = "auto"
|
||||
if platform.system() == "Windows":
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
|
||||
model_root_dir = user_cache_dir("Buzz")
|
||||
model_root_dir = os.path.join(model_root_dir, "models")
|
||||
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
|
||||
|
||||
device = "auto"
|
||||
if platform.system() == "Windows":
|
||||
|
|
|
|||
|
|
@ -69,3 +69,5 @@ combined to produce the final answer.
|
|||
**BUZZ_TRANSLATION_API_BASE_URl** - Base URL of OpenAI compatible API to use for translation. Available from `v1.0.2`.
|
||||
|
||||
**BUZZ_TRANSLATION_API_KEY** - Api key of OpenAI compatible API to use for translation. Available from `v1.0.2`.
|
||||
|
||||
**BUZZ_MODEL_ROOT** - Root directory to store model files. Defaults to [user_cache_dir](https://pypi.org/project/platformdirs/). Available from `v1.0.2`.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue