Download Whisper.cpp models from Hugging Face (#262)

This commit is contained in:
Chidi Williams 2022-12-21 20:44:07 +00:00 committed by GitHub
parent 44aeef95c2
commit 82bdd30fb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 11 deletions

View file

@ -198,8 +198,8 @@ class DownloadModelProgressDialog(QProgressDialog):
start_time: datetime
def __init__(self, parent: Optional[QWidget], *args) -> None:
super().__init__('Downloading resources (0%, unknown time remaining)',
'Cancel', 0, 1_000_000, parent, *args)
super().__init__('Downloading model (0%, unknown time remaining)',
'Cancel', 0, 100, parent, *args)
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.start_time = datetime.now()
@ -211,7 +211,7 @@ class DownloadModelProgressDialog(QProgressDialog):
time_left = (time_spent / fraction_completed) - time_spent
self.setLabelText(
f'Downloading resources ({(fraction_completed):.2%}, {humanize.naturaldelta(time_left)} remaining)')
f'Downloading model ({fraction_completed :.0%}, {humanize.naturaldelta(time_left)} remaining)')
class RecordingTranscriberObject(QObject):
@ -479,7 +479,7 @@ class RecordingTranscriberWidget(QWidget):
self.setWindowTitle('Live Recording')
self.setFixedSize(400, 520)
self.transcription_options = TranscriptionOptions()
self.transcription_options = TranscriptionOptions(model=Model.WHISPER_CPP_TINY)
self.audio_devices_combo_box = AudioDevicesComboBox(self)
self.audio_devices_combo_box.device_changed.connect(

View file

@ -36,18 +36,18 @@ class ModelLoader(QObject):
try:
if self.use_whisper_cpp:
root = user_cache_dir('Buzz')
url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{self.name}.bin'
url = f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{self.name}.bin'
model_path = os.path.join(root, f'ggml-model-whisper-{self.name}.bin')
else:
root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
url = whisper._MODELS[self.name]
model_path = os.path.join(root, os.path.basename(url))
os.makedirs(root, exist_ok=True)
model_path = os.path.join(root, os.path.basename(url))
if os.path.exists(model_path) and not os.path.isfile(model_path):
raise RuntimeError(
f"{model_path} exists and is not a regular file")

View file

@ -159,7 +159,8 @@ class RecordingTranscriber:
self.model_path) if self.use_whisper_cpp else whisper.load_model(self.model_path)
logging.debug(
'Recording, language = %s, task = %s, device = %s, sample rate = %s, model_path = %s, temperature = %s, initial prompt length = %s',
'Recording, language = %s, task = %s, device = %s, sample rate = %s, model_path = %s, temperature = %s, '
'initial prompt length = %s',
self.language, self.task, self.input_device_index, self.sample_rate, self.model_path, self.temperature,
len(self.initial_prompt))
self.current_stream = sounddevice.InputStream(

View file

@ -1,3 +1,4 @@
import logging
import os.path
import pathlib
from unittest.mock import Mock, patch
@ -126,7 +127,7 @@ class TestDownloadModelProgressDialog:
def test_should_show_dialog(self, qtbot: QtBot):
dialog = DownloadModelProgressDialog(parent=None)
qtbot.add_widget(dialog)
assert dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
assert dialog.labelText() == 'Downloading model (0%, unknown time remaining)'
def test_should_update_label_on_progress(self, qtbot: QtBot):
dialog = DownloadModelProgressDialog(parent=None)
@ -134,12 +135,13 @@ class TestDownloadModelProgressDialog:
dialog.set_fraction_completed(0.0)
dialog.set_fraction_completed(0.01)
logging.debug(dialog.labelText())
assert dialog.labelText().startswith(
'Downloading resources (1.00%')
'Downloading model (1%')
dialog.set_fraction_completed(0.1)
assert dialog.labelText().startswith(
'Downloading resources (10.00%')
'Downloading model (10%')
# Other windows should not be processing while models are being downloaded
def test_should_be_an_application_modal(self, qtbot: QtBot):