Refactor model loading (#191)

This commit is contained in:
Chidi Williams 2022-11-28 13:10:39 +00:00 committed by GitHub
parent 3486056dea
commit 219f20d0e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 235 additions and 243 deletions

View file

@ -47,7 +47,7 @@ clean:
rm -rf dist/* || true
test: whisper_cpp.py
pytest --cov
pytest --cov --cov-report=html
dist/Buzz dist/Buzz.app: whisper_cpp.py
pyinstaller --noconfirm Buzz.spec

144
gui.py
View file

@ -10,7 +10,7 @@ import humanize
import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QTimer,
QUrl, pyqtSignal)
QUrl, pyqtSignal, QThreadPool)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor)
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
@ -23,6 +23,7 @@ from whisper import tokenizer
from __version__ import VERSION
from transcriber import FileTranscriber, OutputFormat, RecordingTranscriber
from whispr import LOADED_WHISPER_DLL, Task
from model_loader import ModelLoader
APP_NAME = 'Buzz'
@ -271,13 +272,13 @@ class FileTranscriberObject(QObject):
transcriber: FileTranscriber
def __init__(
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
self, model_path: str, use_whisper_cpp: bool, language: Optional[str],
task: Task, file_path: str, output_file_path: str,
output_format: OutputFormat, word_level_timings: bool,
parent: Optional['QObject'], *args) -> None:
super().__init__(parent, *args)
self.transcriber = FileTranscriber(
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
on_download_model_chunk=self.on_download_model_progress,
language=language, task=task, file_path=file_path,
output_file_path=output_file_path, output_format=output_format,
@ -299,9 +300,6 @@ class FileTranscriberObject(QObject):
def stop(self):
self.transcriber.stop()
def stop_loading_model(self):
self.transcriber.stop_loading_model()
class RecordingTranscriberObject(QObject):
"""
@ -313,11 +311,11 @@ class RecordingTranscriberObject(QObject):
download_model_progress = pyqtSignal(tuple)
transcriber: RecordingTranscriber
def __init__(self, model_name, use_whisper_cpp, language: Optional[str],
def __init__(self, model_path: str, use_whisper_cpp, language: Optional[str],
task: Task, input_device_index: Optional[int], parent: Optional[QWidget], *args) -> None:
super().__init__(parent, *args)
self.transcriber = RecordingTranscriber(
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
on_download_model_chunk=self.on_download_model_progress, language=language,
event_callback=self.event_callback, task=task,
input_device_index=input_device_index)
@ -334,9 +332,6 @@ class RecordingTranscriberObject(QObject):
def stop_recording(self):
self.transcriber.stop_recording()
def stop_loading_model(self):
self.transcriber.stop_loading_model()
class TimerLabel(QLabel):
start_time: Optional[QDateTime]
@ -378,6 +373,11 @@ def get_model_name(quality: Quality) -> str:
}[quality][0]
def show_model_download_error_dialog(parent: QWidget, error: str):
    """Pop up a critical message box describing a model-loading failure.

    `error` is the user-facing description emitted by the model loader.
    """
    QMessageBox.critical(
        parent, '',
        f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more information.')
class FileTranscriberWidget(QWidget):
selected_quality = Quality.VERY_LOW
selected_language: Optional[str] = None
@ -387,6 +387,7 @@ class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
file_transcriber: Optional[FileTranscriberObject] = None
model_loader: Optional[ModelLoader] = None
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
super().__init__(parent)
@ -445,6 +446,7 @@ class FileTranscriberWidget(QWidget):
layout.addWidget(widget, row_index, col_offset, 1, col_width)
self.setLayout(layout)
self.pool = QThreadPool()
def on_quality_changed(self, quality: Quality):
self.selected_quality = quality
@ -476,26 +478,37 @@ class FileTranscriberWidget(QWidget):
self.run_button.setDisabled(True)
model_name = get_model_name(self.selected_quality)
self.file_transcriber = FileTranscriberObject(
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
parent=self)
self.file_transcriber.download_model_progress.connect(
self.on_download_model_progress)
self.file_transcriber.event_received.connect(
self.on_transcriber_event)
def start_file_transcription(model_path: str):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog = None
self.file_transcriber.start()
self.file_transcriber = FileTranscriberObject(
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
parent=self)
self.file_transcriber.event_received.connect(
self.on_transcriber_event)
self.file_transcriber.start()
self.model_loader = ModelLoader(
name=model_name, use_whisper_cpp=use_whisper_cpp)
self.model_loader.signals.progress.connect(
self.on_download_model_progress)
self.model_loader.signals.completed.connect(start_file_transcription)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.pool.start(self.model_loader)
def on_download_model_progress(self, progress: Tuple[int, int]):
(current_size, _) = progress
(current_size, total_size) = progress
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = DownloadModelProgressDialog(
total_size=100, parent=self)
total_size=total_size, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
@ -503,10 +516,12 @@ class FileTranscriberWidget(QWidget):
self.model_download_progress_dialog.setValue(
current_size=current_size)
def on_download_model_error(self, error: str):
show_model_download_error_dialog(self, error)
self.reset_transcription()
def on_transcriber_event(self, event: FileTranscriber.Event):
if isinstance(event, FileTranscriber.LoadedModelEvent):
self.reset_model_download()
elif isinstance(event, FileTranscriber.ProgressEvent):
if isinstance(event, FileTranscriber.ProgressEvent):
current_size = event.current_value
total_size = event.max_value
@ -535,8 +550,8 @@ class FileTranscriberWidget(QWidget):
self.transcriber_progress_dialog = None
def on_cancel_model_progress_dialog(self):
if self.file_transcriber is not None:
self.file_transcriber.stop_loading_model()
if self.model_loader is not None:
self.model_loader.stop()
self.reset_model_download()
def reset_model_download(self):
@ -583,6 +598,7 @@ class RecordingTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
settings: Settings
transcriber: Optional[RecordingTranscriberObject] = None
model_loader: Optional[ModelLoader] = None
def __init__(self, parent: Optional[QWidget]) -> None:
super().__init__(parent)
@ -637,6 +653,8 @@ class RecordingTranscriberWidget(QWidget):
self.setLayout(layout)
self.pool = QThreadPool()
def on_device_changed(self, device_id: int):
self.selected_device_id = device_id
@ -663,32 +681,36 @@ class RecordingTranscriberWidget(QWidget):
model_name = get_model_name(self.selected_quality)
self.transcriber = RecordingTranscriberObject(
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
language=self.selected_language, task=self.selected_task,
input_device_index=self.selected_device_id,
parent=self
)
self.transcriber.event_changed.connect(
self.on_transcriber_event_changed)
self.transcriber.download_model_progress.connect(
self.on_download_model_progress)
self.transcriber.start_recording()
def on_transcriber_event_changed(self, event: RecordingTranscriber.Event):
if isinstance(event, RecordingTranscriber.LoadedModelEvent):
def start_recording_transcription(model_path: str):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.timer_label.start_timer()
self.record_button.setDisabled(False)
self.reset_model_download()
elif isinstance(event, RecordingTranscriber.TranscribedNextChunkEvent):
text = event.text.strip()
if len(text) > 0:
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
self.text_box.insertPlainText(text + '\n\n')
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog = None
self.transcriber = RecordingTranscriberObject(
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
language=self.selected_language, task=self.selected_task,
input_device_index=self.selected_device_id,
parent=self
)
self.transcriber.event_changed.connect(
self.on_transcriber_event_changed)
self.transcriber.download_model_progress.connect(
self.on_download_model_progress)
self.transcriber.start_recording()
self.model_loader = ModelLoader(
name=model_name, use_whisper_cpp=use_whisper_cpp)
self.model_loader.signals.progress.connect(
self.on_download_model_progress)
self.model_loader.signals.completed.connect(
start_recording_transcription)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.pool.start(self.model_loader)
def on_download_model_progress(self, progress: Tuple[int, int]):
(current_size, _) = progress
@ -703,14 +725,26 @@ class RecordingTranscriberWidget(QWidget):
self.model_download_progress_dialog.setValue(
current_size=current_size)
def on_download_model_error(self, error: str):
show_model_download_error_dialog(self, error)
self.stop_recording()
def on_transcriber_event_changed(self, event: RecordingTranscriber.Event):
if isinstance(event, RecordingTranscriber.TranscribedNextChunkEvent):
text = event.text.strip()
if len(text) > 0:
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
self.text_box.insertPlainText(text + '\n\n')
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
def stop_recording(self):
if self.transcriber is not None:
self.transcriber.stop_recording()
self.timer_label.stop_timer()
def on_cancel_model_progress_dialog(self):
if self.transcriber is not None:
self.transcriber.stop_loading_model()
if self.model_loader is not None:
self.model_loader.stop()
self.reset_model_download()
self.record_button.force_stop()
self.record_button.setDisabled(False)
@ -736,7 +770,7 @@ class AppIcon(QIcon):
class AboutDialog(QDialog):
def __init__(self, parent: Optional[QWidget]) -> None:
def __init__(self, parent: Optional[QWidget]=None) -> None:
super().__init__(parent)
self.setFixedSize(200, 200)

View file

@ -7,7 +7,7 @@ from PyQt6.QtCore import Qt
from gui import (Application, AudioDevicesComboBox,
DownloadModelProgressDialog, FileTranscriberWidget,
LanguagesComboBox, MainWindow, OutputFormatsComboBox, Quality, Settings,
LanguagesComboBox, MainWindow, OutputFormatsComboBox, Quality, Settings, AboutDialog,
QualityComboBox, TranscriberProgressDialog)
from transcriber import OutputFormat
@ -204,3 +204,9 @@ class TestSettings:
settings.set_enable_ggml_inference(False)
assert settings.get_enable_ggml_inference() is False
class TestAboutDialog:
    def test_should_create(self):
        # The dialog should construct successfully with its default (no) parent.
        assert AboutDialog() is not None

96
model_loader.py Normal file
View file

@ -0,0 +1,96 @@
import hashlib
import logging
import os
import warnings
import requests
import whisper
from platformdirs import user_cache_dir
from PyQt6.QtCore import QObject, QRunnable, pyqtSignal, pyqtSlot
# SHA256 hex digests of the whisper.cpp GGML model binaries, keyed by model
# name. Used to decide whether a cached file is intact (skip the download)
# and to verify a freshly downloaded file.
MODELS_SHA256 = {
    'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
    'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
    'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
}
class ModelLoader(QRunnable):
    """Downloads (or locates in the cache) a Whisper or whisper.cpp model.

    Intended to be scheduled on a QThreadPool. Reports progress, completion
    (with the local model file path), and errors through Qt signals so the
    GUI thread can react without blocking.
    """

    class Signals(QObject):
        progress = pyqtSignal(tuple)  # (current_size, total_size) in bytes
        completed = pyqtSignal(str)   # path of the verified local model file
        error = pyqtSignal(str)       # user-facing error message

    signals: Signals
    # Cooperative cancellation flag, checked once per downloaded chunk.
    stopped = False

    def __init__(self, name: str, use_whisper_cpp=False) -> None:
        super().__init__()
        self.name = name
        self.use_whisper_cpp = use_whisper_cpp
        self.signals = self.Signals()

    @pyqtSlot()
    def run(self):
        """Resolve the model file, downloading and checksum-verifying it if needed."""
        try:
            if self.use_whisper_cpp:
                root = user_cache_dir('Buzz')
                url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{self.name}.bin'
            else:
                # Mirror whisper's own cache directory so models fetched by
                # either code path are shared.
                root = os.getenv(
                    "XDG_CACHE_HOME",
                    os.path.join(os.path.expanduser("~"), ".cache", "whisper")
                )
                url = whisper._MODELS[self.name]

            os.makedirs(root, exist_ok=True)
            model_path = os.path.join(root, os.path.basename(url))

            if os.path.exists(model_path) and not os.path.isfile(model_path):
                raise RuntimeError(
                    f"{model_path} exists and is not a regular file")

            # whisper.cpp checksums are pinned locally; whisper's download URLs
            # embed the expected checksum as the second-to-last path segment.
            expected_sha256 = MODELS_SHA256[self.name] if self.use_whisper_cpp else url.split(
                "/")[-2]

            if os.path.isfile(model_path):
                if self._file_sha256(model_path) == expected_sha256:
                    self.signals.completed.emit(model_path)
                    return
                else:
                    warnings.warn(
                        f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")

            # Downloads the model using the requests module instead of urllib to
            # use the certs from certifi when the app is running in frozen mode
            with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
                source.raise_for_status()
                total_size = int(source.headers.get('Content-Length', 0))
                current = 0
                self.signals.progress.emit((0, total_size))
                for chunk in source.iter_content(chunk_size=8192):
                    if self.stopped:
                        return
                    output.write(chunk)
                    current += len(chunk)
                    self.signals.progress.emit((current, total_size))

            if self._file_sha256(model_path) != expected_sha256:
                raise RuntimeError(
                    "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")

            self.signals.completed.emit(model_path)
        except RuntimeError as exc:
            self.signals.error.emit(str(exc))
            logging.exception('')
        except requests.RequestException:
            self.signals.error.emit('A connection error occurred.')
            logging.exception('')
        except Exception:
            self.signals.error.emit('An unknown error occurred.')
            logging.exception('')

    @staticmethod
    def _file_sha256(path: str) -> str:
        """Return the SHA256 hex digest of the file at `path`.

        Hashes in 1 MiB chunks inside a context manager instead of
        `open(path, 'rb').read()`, which leaked the file handle and loaded
        the entire (potentially very large) model into memory at once.
        """
        digest = hashlib.sha256()
        with open(path, 'rb') as file:
            for chunk in iter(lambda: file.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def stop(self):
        """Ask a running download to stop; takes effect at the next chunk."""
        self.stopped = True

View file

@ -19,8 +19,7 @@ import whisper
from sounddevice import PortAudioError
from conn import pipe_stderr, pipe_stdout
from whispr import (ModelLoader, Segment, Stopped, Task, WhisperCpp,
read_progress, whisper_cpp_params)
from whispr import Segment, Task, WhisperCpp, read_progress, whisper_cpp_params
class RecordingTranscriber:
@ -34,20 +33,19 @@ class RecordingTranscriber:
class Event:
pass
class LoadedModelEvent(Event):
pass
@dataclass
class TranscribedNextChunkEvent(Event):
text: str
def __init__(self,
model_name: str, use_whisper_cpp: bool,
model_path: str, use_whisper_cpp: bool,
language: Optional[str], task: Task,
on_download_model_chunk: Callable[[
int, int], None] = lambda *_: None,
event_callback: Callable[[Event], None] = lambda *_: None,
input_device_index: Optional[int] = None) -> None:
self.model_path = model_path
self.use_whisper_cpp = use_whisper_cpp
self.current_stream = None
self.event_callback = event_callback
self.language = language
@ -62,24 +60,18 @@ class RecordingTranscriber:
self.mutex = threading.Lock()
self.text = ''
self.on_download_model_chunk = on_download_model_chunk
self.model_loader = ModelLoader(
name=model_name, use_whisper_cpp=use_whisper_cpp,)
def start_recording(self):
self.current_thread = Thread(target=self.process_queue)
self.current_thread.start()
def process_queue(self):
try:
model = self.model_loader.load(
on_download_model_chunk=self.on_download_model_chunk)
except Stopped:
return
model = WhisperCpp(
self.model_path) if self.use_whisper_cpp else whisper.load_model(self.model_path)
self.event_callback(self.LoadedModelEvent())
logging.debug('Recording, language = %s, task = %s, device = %s, sample rate = %s',
self.language, self.task, self.input_device_index, self.sample_rate)
logging.debug(
'Recording, language = %s, task = %s, device = %s, sample rate = %s, model_path = %s',
self.language, self.task, self.input_device_index, self.sample_rate, self.model_path)
self.current_stream = sounddevice.InputStream(
samplerate=self.sample_rate,
blocksize=1 * self.sample_rate, # 1 sec
@ -160,9 +152,6 @@ class RecordingTranscriber:
self.current_thread.join()
logging.debug('Recording thread terminated')
def stop_loading_model(self):
self.model_loader.stop()
class OutputFormat(enum.Enum):
TXT = 'txt'
@ -226,15 +215,12 @@ class FileTranscriber:
current_value: int
max_value: int
class LoadedModelEvent(Event):
pass
class CompletedTranscriptionEvent(Event):
pass
def __init__(
self,
model_name: str, use_whisper_cpp: bool,
model_path: str, use_whisper_cpp: bool,
language: Optional[str], task: Task, file_path: str,
output_file_path: str, output_format: OutputFormat,
word_level_timings: bool,
@ -249,12 +235,9 @@ class FileTranscriber:
self.open_file_on_complete = open_file_on_complete
self.output_format = output_format
self.word_level_timings = word_level_timings
self.model_name = model_name
self.model_path = model_path
self.use_whisper_cpp = use_whisper_cpp
self.on_download_model_chunk = on_download_model_chunk
self.model_loader = ModelLoader(self.model_name, self.use_whisper_cpp)
self.event_callback = event_callback
def start(self):
@ -262,21 +245,10 @@ class FileTranscriber:
self.current_thread.start()
def transcribe(self):
if self.stopped:
return
try:
model_path = self.model_loader.get_model_path(
on_download_model_chunk=self.on_download_model_chunk)
except Stopped:
return
self.event_callback(self.LoadedModelEvent())
time_started = datetime.datetime.now()
logging.debug(
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s',
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path)
self.file_path, self.language, self.task, self.output_file_path, self.output_format, self.model_path)
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
@ -288,7 +260,7 @@ class FileTranscriber:
self.current_process = multiprocessing.Process(
target=transcribe_whisper_cpp,
args=(
send_pipe, model_path, self.file_path,
send_pipe, self.model_path, self.file_path,
self.output_file_path, self.open_file_on_complete,
self.output_format,
self.language if self.language is not None else 'en',
@ -299,7 +271,7 @@ class FileTranscriber:
self.current_process = multiprocessing.Process(
target=transcribe_whisper,
args=(
send_pipe, model_path, self.file_path,
send_pipe, self.model_path, self.file_path,
self.language, self.task, self.output_file_path,
self.open_file_on_complete, self.output_format,
self.word_level_timings
@ -324,9 +296,6 @@ class FileTranscriber:
logging.debug('Completed file transcription, time taken = %s',
datetime.datetime.now()-time_started)
def stop_loading_model(self):
self.model_loader.stop()
def join(self):
if self.current_thread is not None:
self.current_thread.join()
@ -335,8 +304,6 @@ class FileTranscriber:
if self.stopped is False:
self.stopped = True
self.model_loader.stop()
if self.current_process is not None and self.current_process.is_alive():
self.current_process.terminate()
logging.debug('File transcription process terminated')

View file

@ -1,19 +1,35 @@
import logging
import os
import pathlib
import tempfile
import time
import pytest
from model_loader import ModelLoader
from transcriber import (FileTranscriber, OutputFormat, RecordingTranscriber,
to_timestamp)
from whispr import Task
def get_model_path(model_name: str, use_whisper_cpp: bool) -> str:
    """Run a ModelLoader synchronously and return the resolved model path.

    Returns the empty string if the loader never emits `completed`.
    """
    loader = ModelLoader(model_name, use_whisper_cpp)
    completed_paths: list = []
    loader.signals.completed.connect(completed_paths.append)
    loader.run()
    return completed_paths[0] if completed_paths else ''
class TestRecordingTranscriber:
def test_transcriber(self):
model_path = get_model_path('tiny', True)
transcriber = RecordingTranscriber(
model_name='tiny', use_whisper_cpp=True, language='en',
model_path=model_path, use_whisper_cpp=True, language='en',
task=Task.TRANSCRIBE)
assert transcriber is not None
@ -36,7 +52,8 @@ class TestFileTranscriber:
(False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle, un podcast'),
(False, OutputFormat.SRT, '1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
(False, OutputFormat.VTT, 'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
(True, OutputFormat.SRT, '1\n00:00:00.040 --> 00:00:00.359\n Bienvenue dans\n\n2\n00:00:00.359 --> 00:00:00.419\n Passe-'),
(True, OutputFormat.SRT,
'1\n00:00:00.040 --> 00:00:00.359\n Bienvenue dans\n\n2\n00:00:00.359 --> 00:00:00.419\n Passe-'),
])
def test_transcribe_whisper(self, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str):
output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
@ -46,8 +63,9 @@ class TestFileTranscriber:
def event_callback(event: FileTranscriber.Event):
events.append(event)
model_path = get_model_path('tiny', False)
transcriber = FileTranscriber(
model_name='tiny', use_whisper_cpp=False, language='fr',
model_path=model_path, use_whisper_cpp=False, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path.as_posix(), output_format=output_format,
open_file_on_complete=False, event_callback=event_callback,
@ -78,19 +96,19 @@ class TestFileTranscriber:
def event_callback(event: FileTranscriber.Event):
events.append(event)
model_path = get_model_path('tiny', False)
transcriber = FileTranscriber(
model_name='tiny', use_whisper_cpp=False, language='fr',
model_path=model_path, use_whisper_cpp=False, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback,
word_level_timings=False)
transcriber.start()
time.sleep(1)
transcriber.stop()
# Assert that file was not created and there was no completed progress event
# Assert that file was not created
assert os.path.isfile(output_file_path) is False
assert any([isinstance(event, FileTranscriber.ProgressEvent)
and event.current_value == event.max_value for event in events]) is False
def test_transcribe_whisper_cpp(self):
output_file_path = os.path.join(
@ -103,8 +121,9 @@ class TestFileTranscriber:
def event_callback(event: FileTranscriber.Event):
events.append(event)
model_path = get_model_path('tiny', True)
transcriber = FileTranscriber(
model_name='tiny', use_whisper_cpp=True, language='fr',
model_path=model_path, use_whisper_cpp=True, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback,

130
whispr.py
View file

@ -1,24 +1,13 @@
import ctypes
import enum
import hashlib
import logging
import multiprocessing
import os
import warnings
from dataclasses import dataclass
from multiprocessing.connection import Connection
from queue import Empty
from threading import Thread
from typing import Any, Callable, List, Union
import numpy as np
import requests
import whisper
from appdirs import user_cache_dir
from tqdm import tqdm
from whisper import Whisper
from conn import pipe_stderr
# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
@ -100,125 +89,6 @@ class WhisperCpp:
whisper_cpp.whisper_free((self.ctx))
class ModelLoader:
    """Loads a Whisper/whisper.cpp model by downloading it in a child process.

    A pipe carries the child's stderr (where download progress is written)
    back to the parent, and a queue carries the resulting model path. The
    exact start/join/close ordering in `get_model_path` is load-bearing.
    """
    process: multiprocessing.Process
    model_path_queue: multiprocessing.Queue

    def __init__(self, name: str, use_whisper_cpp=False) -> None:
        self.name = name
        self.use_whisper_cpp = use_whisper_cpp
        # recv end is read by the progress thread; send end is given to the child.
        self.recv_pipe, self.send_pipe = multiprocessing.Pipe(duplex=False)
        self.model_path_queue = multiprocessing.Queue()
        self.process = multiprocessing.Process(
            target=self.load_whisper_cpp_model if self.use_whisper_cpp else self.load_whisper_model,
            args=(self.send_pipe, self.model_path_queue, self.name))

    def get_model_path(self, on_download_model_chunk: Callable[[int, int], None] = lambda *_: None) -> str:
        """Run the download process to completion and return the model path.

        Raises `Stopped` if the process ended without producing a path
        (e.g. it was terminated via `stop()`).
        """
        logging.debug(
            'Loading model = %s, whisper.cpp = %s', self.name, self.use_whisper_cpp)

        # Fixes an issue with the pickling of a torch model from another process
        os.environ["no_proxy"] = '*'

        # Report 0% immediately so the caller can show a progress dialog.
        on_download_model_chunk(0, 100)

        self.process.start()

        # Forward the child's progress output to the caller's callback.
        thread = Thread(target=read_progress, args=(
            self.recv_pipe, self.use_whisper_cpp, on_download_model_chunk))
        thread.start()

        self.process.join()

        self.recv_pipe.close()
        self.send_pipe.close()

        on_download_model_chunk(100, 100)

        try:
            # Non-blocking: an empty queue means the child never finished.
            model_path = self.model_path_queue.get(block=False)
            logging.debug('Model path = %s', model_path)
            return model_path
        except Empty as exc:
            raise Stopped from exc

    def load(self, on_download_model_chunk: Callable[[int, int], None] = lambda *_: None) -> Union[Whisper, WhisperCpp]:
        """Download (if needed) and instantiate the model object itself."""
        model_path = self.get_model_path(on_download_model_chunk)
        return WhisperCpp(model_path) if self.use_whisper_cpp else whisper.load_model(model_path)

    def load_whisper_cpp_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
        # Child-process entry point for whisper.cpp models.
        path = download_model(name, use_whisper_cpp=True)
        queue.put(path)

    def load_whisper_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
        # Child-process entry point for whisper models; stderr is piped to the
        # parent so tqdm's progress output can be parsed there.
        with pipe_stderr(stderr_conn):
            path = download_model(name, use_whisper_cpp=False)
            queue.put(path)

    def stop(self):
        """Terminate the download process if it is still running."""
        if self.process.is_alive():
            self.process.terminate()

    def is_alive(self):
        return self.process.is_alive()
# SHA256 hex digests of the whisper.cpp GGML model binaries, keyed by model
# name; used by `download_model` to validate cached and downloaded files.
MODELS_SHA256 = {
    'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
    'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
    'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
}
def download_model(name: str, use_whisper_cpp=False):
    """Ensure the named model exists in the local cache and return its path.

    Verifies an existing cached file against its expected SHA256 before
    skipping the download, and re-verifies after downloading.

    Raises RuntimeError if the cache path is not a regular file or if the
    downloaded file fails checksum verification.
    """
    if use_whisper_cpp:
        root = user_cache_dir('Buzz')
        url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{name}.bin'
    else:
        # Mirror whisper's own cache directory so models are shared with it.
        root = os.getenv(
            "XDG_CACHE_HOME",
            os.path.join(os.path.expanduser("~"), ".cache", "whisper")
        )
        url = whisper._MODELS[name]

    os.makedirs(root, exist_ok=True)
    model_path = os.path.join(root, os.path.basename(url))

    if os.path.exists(model_path) and not os.path.isfile(model_path):
        raise RuntimeError(
            f"{model_path} exists and is not a regular file")

    # whisper.cpp checksums are pinned locally; whisper's download URLs embed
    # the expected checksum as the second-to-last path segment.
    expected_sha256 = MODELS_SHA256[name] if use_whisper_cpp else url.split(
        "/")[-2]

    if os.path.isfile(model_path):
        # Close the file promptly instead of open(...).read(), which leaked
        # the handle.
        with open(model_path, "rb") as model_file:
            model_bytes = model_file.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_path
        else:
            warnings.warn(
                f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")

    # Downloads the model using the requests module instead of urllib to
    # use the certs from certifi when the app is running in frozen mode
    with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
        source.raise_for_status()
        total_size = int(source.headers.get('Content-Length', 0))
        with tqdm(total=total_size, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            for chunk in source.iter_content(chunk_size=8192):
                output.write(chunk)
                loop.update(len(chunk))

    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")

    return model_path
# tqdm progress line looks like: " 54%|█████ |"
def tqdm_progress(line: str):
percent_progress = line.split('|')[0].strip().strip('%')