diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 163d39ab..2094e5b4 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -36,6 +36,7 @@ except ImportError: model_root_dir = user_cache_dir("Buzz") model_root_dir = os.path.join(model_root_dir, "models") +model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir) os.makedirs(model_root_dir, exist_ok=True) logging.debug("Model root directory: %s", model_root_dir) @@ -270,7 +271,6 @@ class HuggingfaceDownloadMonitor: self.model_root = model_root self.progress = progress self.total_file_size = total_file_size - self.tmp_download_root = None self.incomplete_download_root = None self.stop_event = threading.Event() self.monitor_thread = None @@ -278,25 +278,20 @@ class HuggingfaceDownloadMonitor: def set_download_roots(self): normalized_model_root = os.path.normpath(self.model_root) - normalized_hub_path = os.path.normpath("/models/") - index = normalized_model_root.find(normalized_hub_path) - if index > 0: - self.tmp_download_root = normalized_model_root[:index + len(normalized_hub_path)] - two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", "..")) self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs")) def clean_tmp_files(self): - for filename in os.listdir(self.tmp_download_root): + for filename in os.listdir(model_root_dir): if filename.startswith("tmp"): - os.remove(os.path.join(self.tmp_download_root, filename)) + os.remove(os.path.join(model_root_dir, filename)) def monitor_file_size(self): while not self.stop_event.is_set(): - if self.tmp_download_root is not None: - for filename in os.listdir(self.tmp_download_root): + if model_root_dir is not None: + for filename in os.listdir(model_root_dir): if filename.startswith("tmp"): - file_size = os.path.getsize(os.path.join(self.tmp_download_root, filename)) + file_size = os.path.getsize(os.path.join(model_root_dir, filename)) self.progress.emit((file_size, self.total_file_size)) for filename in os.listdir(self.incomplete_download_root): diff --git a/buzz/transcriber/recording_transcriber.py b/buzz/transcriber/recording_transcriber.py index bfbaa86a..da58d7ee 100644 --- a/buzz/transcriber/recording_transcriber.py +++ b/buzz/transcriber/recording_transcriber.py @@ -73,6 +73,7 @@ class RecordingTranscriber(QObject): elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER: model_root_dir = user_cache_dir("Buzz") model_root_dir = os.path.join(model_root_dir, "models") + model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir) device = "auto" if platform.system() == "Windows": diff --git a/buzz/transcriber/whisper_file_transcriber.py b/buzz/transcriber/whisper_file_transcriber.py index 3585276d..f9bf1f58 100644 --- a/buzz/transcriber/whisper_file_transcriber.py +++ b/buzz/transcriber/whisper_file_transcriber.py @@ -143,6 +143,7 @@ class WhisperFileTranscriber(FileTranscriber): model_root_dir = user_cache_dir("Buzz") model_root_dir = os.path.join(model_root_dir, "models") + model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir) device = "auto" if platform.system() == "Windows": diff --git a/docs/docs/preferences.md b/docs/docs/preferences.md index d5405456..86e09875 100644 --- a/docs/docs/preferences.md +++ b/docs/docs/preferences.md @@ -69,3 +69,5 @@ combined to produce the final answer. **BUZZ_TRANSLATION_API_BASE_URl** - Base URL of OpenAI compatible API to use for translation. Available from `v1.0.2`. **BUZZ_TRANSLATION_API_KEY** - Api key of OpenAI compatible API to use for translation. Available from `v1.0.2`. + +**BUZZ_MODEL_ROOT** - Root directory to store model files. Defaults to [user_cache_dir](https://pypi.org/project/platformdirs/). Available from `v1.0.2`.