mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-29 05:00:21 +02:00
Download Whisper.cpp models from Hugging Face (#262)
This commit is contained in:
parent
44aeef95c2
commit
82bdd30fb8
|
@ -198,8 +198,8 @@ class DownloadModelProgressDialog(QProgressDialog):
|
|||
start_time: datetime
|
||||
|
||||
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__('Downloading resources (0%, unknown time remaining)',
|
||||
'Cancel', 0, 1_000_000, parent, *args)
|
||||
super().__init__('Downloading model (0%, unknown time remaining)',
|
||||
'Cancel', 0, 100, parent, *args)
|
||||
self.setWindowModality(Qt.WindowModality.ApplicationModal)
|
||||
self.start_time = datetime.now()
|
||||
|
||||
|
@ -211,7 +211,7 @@ class DownloadModelProgressDialog(QProgressDialog):
|
|||
time_left = (time_spent / fraction_completed) - time_spent
|
||||
|
||||
self.setLabelText(
|
||||
f'Downloading resources ({(fraction_completed):.2%}, {humanize.naturaldelta(time_left)} remaining)')
|
||||
f'Downloading model ({fraction_completed :.0%}, {humanize.naturaldelta(time_left)} remaining)')
|
||||
|
||||
|
||||
class RecordingTranscriberObject(QObject):
|
||||
|
@ -479,7 +479,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.setWindowTitle('Live Recording')
|
||||
self.setFixedSize(400, 520)
|
||||
|
||||
self.transcription_options = TranscriptionOptions()
|
||||
self.transcription_options = TranscriptionOptions(model=Model.WHISPER_CPP_TINY)
|
||||
|
||||
self.audio_devices_combo_box = AudioDevicesComboBox(self)
|
||||
self.audio_devices_combo_box.device_changed.connect(
|
||||
|
|
|
@ -36,18 +36,18 @@ class ModelLoader(QObject):
|
|||
try:
|
||||
if self.use_whisper_cpp:
|
||||
root = user_cache_dir('Buzz')
|
||||
url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{self.name}.bin'
|
||||
url = f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{self.name}.bin'
|
||||
model_path = os.path.join(root, f'ggml-model-whisper-{self.name}.bin')
|
||||
else:
|
||||
root = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
url = whisper._MODELS[self.name]
|
||||
model_path = os.path.join(root, os.path.basename(url))
|
||||
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
model_path = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(model_path) and not os.path.isfile(model_path):
|
||||
raise RuntimeError(
|
||||
f"{model_path} exists and is not a regular file")
|
||||
|
|
|
@ -159,7 +159,8 @@ class RecordingTranscriber:
|
|||
self.model_path) if self.use_whisper_cpp else whisper.load_model(self.model_path)
|
||||
|
||||
logging.debug(
|
||||
'Recording, language = %s, task = %s, device = %s, sample rate = %s, model_path = %s, temperature = %s, initial prompt length = %s',
|
||||
'Recording, language = %s, task = %s, device = %s, sample rate = %s, model_path = %s, temperature = %s, '
|
||||
'initial prompt length = %s',
|
||||
self.language, self.task, self.input_device_index, self.sample_rate, self.model_path, self.temperature,
|
||||
len(self.initial_prompt))
|
||||
self.current_stream = sounddevice.InputStream(
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os.path
|
||||
import pathlib
|
||||
from unittest.mock import Mock, patch
|
||||
|
@ -126,7 +127,7 @@ class TestDownloadModelProgressDialog:
|
|||
def test_should_show_dialog(self, qtbot: QtBot):
|
||||
dialog = DownloadModelProgressDialog(parent=None)
|
||||
qtbot.add_widget(dialog)
|
||||
assert dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
|
||||
assert dialog.labelText() == 'Downloading model (0%, unknown time remaining)'
|
||||
|
||||
def test_should_update_label_on_progress(self, qtbot: QtBot):
|
||||
dialog = DownloadModelProgressDialog(parent=None)
|
||||
|
@ -134,12 +135,13 @@ class TestDownloadModelProgressDialog:
|
|||
dialog.set_fraction_completed(0.0)
|
||||
|
||||
dialog.set_fraction_completed(0.01)
|
||||
logging.debug(dialog.labelText())
|
||||
assert dialog.labelText().startswith(
|
||||
'Downloading resources (1.00%')
|
||||
'Downloading model (1%')
|
||||
|
||||
dialog.set_fraction_completed(0.1)
|
||||
assert dialog.labelText().startswith(
|
||||
'Downloading resources (10.00%')
|
||||
'Downloading model (10%')
|
||||
|
||||
# Other windows should not be processing while models are being downloaded
|
||||
def test_should_be_an_application_modal(self, qtbot: QtBot):
|
||||
|
|
Loading…
Reference in a new issue