Adding option to specify custom model root (#894)

This commit is contained in:
Raivis Dejus 2024-08-27 20:43:37 +03:00 committed by GitHub
commit 4d9547d9c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 10 additions and 11 deletions

View file

@ -36,6 +36,7 @@ except ImportError:
model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
os.makedirs(model_root_dir, exist_ok=True)
logging.debug("Model root directory: %s", model_root_dir)
@ -270,7 +271,6 @@ class HuggingfaceDownloadMonitor:
self.model_root = model_root
self.progress = progress
self.total_file_size = total_file_size
self.tmp_download_root = None
self.incomplete_download_root = None
self.stop_event = threading.Event()
self.monitor_thread = None
@ -278,25 +278,20 @@ class HuggingfaceDownloadMonitor:
def set_download_roots(self):
normalized_model_root = os.path.normpath(self.model_root)
normalized_hub_path = os.path.normpath("/models/")
index = normalized_model_root.find(normalized_hub_path)
if index > 0:
self.tmp_download_root = normalized_model_root[:index + len(normalized_hub_path)]
two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))
def clean_tmp_files(self):
for filename in os.listdir(self.tmp_download_root):
for filename in os.listdir(model_root_dir):
if filename.startswith("tmp"):
os.remove(os.path.join(self.tmp_download_root, filename))
os.remove(os.path.join(model_root_dir, filename))
def monitor_file_size(self):
while not self.stop_event.is_set():
if self.tmp_download_root is not None:
for filename in os.listdir(self.tmp_download_root):
if model_root_dir is not None:
for filename in os.listdir(model_root_dir):
if filename.startswith("tmp"):
file_size = os.path.getsize(os.path.join(self.tmp_download_root, filename))
file_size = os.path.getsize(os.path.join(model_root_dir, filename))
self.progress.emit((file_size, self.total_file_size))
for filename in os.listdir(self.incomplete_download_root):

View file

@ -73,6 +73,7 @@ class RecordingTranscriber(QObject):
elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
device = "auto"
if platform.system() == "Windows":

View file

@ -143,6 +143,7 @@ class WhisperFileTranscriber(FileTranscriber):
model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
device = "auto"
if platform.system() == "Windows":

View file

@ -69,3 +69,5 @@ combined to produce the final answer.
**BUZZ_TRANSLATION_API_BASE_URl** - Base URL of OpenAI compatible API to use for translation. Available from `v1.0.2`.
**BUZZ_TRANSLATION_API_KEY** - Api key of OpenAI compatible API to use for translation. Available from `v1.0.2`.
**BUZZ_MODEL_ROOT** - Root directory to store model files. Defaults to [user_cache_dir](https://pypi.org/project/platformdirs/). Available from `v1.0.2`.