mirror of https://github.com/chidiwilliams/buzz.git
synced 2024-06-29 05:00:21 +02:00

Refactor model loading (#191)

parent 3486056dea
commit 219f20d0e2
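
Model loading moves out of the transcribers and into a new model_loader.ModelLoader, a QRunnable that downloads and checksum-verifies the model on a QThreadPool and reports progress, the resolved file path, and errors through Qt signals. FileTranscriber and RecordingTranscriber now receive a ready model_path instead of a model_name, so the old multiprocessing-based ModelLoader in whispr.py, the LoadedModelEvent classes, and the stop_loading_model() hooks are removed.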
Makefile (2 changed lines)

@@ -47,7 +47,7 @@ clean:
 	rm -rf dist/* || true
 
 test: whisper_cpp.py
-	pytest --cov
+	pytest --cov --cov-report=html
 
 dist/Buzz dist/Buzz.app: whisper_cpp.py
 	pyinstaller --noconfirm Buzz.spec
gui.py (144 changed lines)

@@ -10,7 +10,7 @@ import humanize
 import sounddevice
 from PyQt6 import QtGui
 from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QTimer,
-                          QUrl, pyqtSignal)
+                          QUrl, pyqtSignal, QThreadPool)
 from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
                          QKeySequence, QPixmap, QTextCursor)
 from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
@@ -23,6 +23,7 @@ from whisper import tokenizer
 from __version__ import VERSION
 from transcriber import FileTranscriber, OutputFormat, RecordingTranscriber
 from whispr import LOADED_WHISPER_DLL, Task
+from model_loader import ModelLoader
 
 APP_NAME = 'Buzz'
 
@@ -271,13 +272,13 @@ class FileTranscriberObject(QObject):
     transcriber: FileTranscriber
 
     def __init__(
-            self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
+            self, model_path: str, use_whisper_cpp: bool, language: Optional[str],
             task: Task, file_path: str, output_file_path: str,
             output_format: OutputFormat, word_level_timings: bool,
             parent: Optional['QObject'], *args) -> None:
         super().__init__(parent, *args)
         self.transcriber = FileTranscriber(
-            model_name=model_name, use_whisper_cpp=use_whisper_cpp,
+            model_path=model_path, use_whisper_cpp=use_whisper_cpp,
            on_download_model_chunk=self.on_download_model_progress,
            language=language, task=task, file_path=file_path,
            output_file_path=output_file_path, output_format=output_format,
@@ -299,9 +300,6 @@ class FileTranscriberObject(QObject):
     def stop(self):
         self.transcriber.stop()
 
-    def stop_loading_model(self):
-        self.transcriber.stop_loading_model()
-
 
 class RecordingTranscriberObject(QObject):
     """
@@ -313,11 +311,11 @@ class RecordingTranscriberObject(QObject):
     download_model_progress = pyqtSignal(tuple)
     transcriber: RecordingTranscriber
 
-    def __init__(self, model_name, use_whisper_cpp, language: Optional[str],
+    def __init__(self, model_path: str, use_whisper_cpp, language: Optional[str],
                  task: Task, input_device_index: Optional[int], parent: Optional[QWidget], *args) -> None:
         super().__init__(parent, *args)
         self.transcriber = RecordingTranscriber(
-            model_name=model_name, use_whisper_cpp=use_whisper_cpp,
+            model_path=model_path, use_whisper_cpp=use_whisper_cpp,
             on_download_model_chunk=self.on_download_model_progress, language=language,
             event_callback=self.event_callback, task=task,
             input_device_index=input_device_index)
@@ -334,9 +332,6 @@ class RecordingTranscriberObject(QObject):
     def stop_recording(self):
         self.transcriber.stop_recording()
 
-    def stop_loading_model(self):
-        self.transcriber.stop_loading_model()
-
 
 class TimerLabel(QLabel):
     start_time: Optional[QDateTime]
@@ -378,6 +373,11 @@ def get_model_name(quality: Quality) -> str:
     }[quality][0]
 
 
+def show_model_download_error_dialog(parent: QWidget, error: str):
+    message = f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more information.'
+    QMessageBox.critical(parent, '', message)
+
+
 class FileTranscriberWidget(QWidget):
     selected_quality = Quality.VERY_LOW
     selected_language: Optional[str] = None
@@ -387,6 +387,7 @@ class FileTranscriberWidget(QWidget):
     model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
     transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
     file_transcriber: Optional[FileTranscriberObject] = None
+    model_loader: Optional[ModelLoader] = None
 
     def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
         super().__init__(parent)
@@ -445,6 +446,7 @@ class FileTranscriberWidget(QWidget):
             layout.addWidget(widget, row_index, col_offset, 1, col_width)
 
         self.setLayout(layout)
+        self.pool = QThreadPool()
 
     def on_quality_changed(self, quality: Quality):
         self.selected_quality = quality
@@ -476,26 +478,37 @@ class FileTranscriberWidget(QWidget):
         self.run_button.setDisabled(True)
         model_name = get_model_name(self.selected_quality)
 
-        self.file_transcriber = FileTranscriberObject(
-            model_name=model_name, use_whisper_cpp=use_whisper_cpp,
-            file_path=self.file_path,
-            language=self.selected_language, task=self.selected_task,
-            output_file_path=output_file, output_format=self.selected_output_format,
-            word_level_timings=self.enabled_word_level_timings,
-            parent=self)
-        self.file_transcriber.download_model_progress.connect(
-            self.on_download_model_progress)
-        self.file_transcriber.event_received.connect(
-            self.on_transcriber_event)
-
-        self.file_transcriber.start()
+        def start_file_transcription(model_path: str):
+            if self.model_download_progress_dialog is not None:
+                self.model_download_progress_dialog = None
+
+            self.file_transcriber = FileTranscriberObject(
+                model_path=model_path, use_whisper_cpp=use_whisper_cpp,
+                file_path=self.file_path,
+                language=self.selected_language, task=self.selected_task,
+                output_file_path=output_file, output_format=self.selected_output_format,
+                word_level_timings=self.enabled_word_level_timings,
+                parent=self)
+            self.file_transcriber.event_received.connect(
+                self.on_transcriber_event)
+
+            self.file_transcriber.start()
+
+        self.model_loader = ModelLoader(
+            name=model_name, use_whisper_cpp=use_whisper_cpp)
+        self.model_loader.signals.progress.connect(
+            self.on_download_model_progress)
+        self.model_loader.signals.completed.connect(start_file_transcription)
+        self.model_loader.signals.error.connect(self.on_download_model_error)
+
+        self.pool.start(self.model_loader)
 
     def on_download_model_progress(self, progress: Tuple[int, int]):
-        (current_size, _) = progress
+        (current_size, total_size) = progress
 
         if self.model_download_progress_dialog is None:
             self.model_download_progress_dialog = DownloadModelProgressDialog(
-                total_size=100, parent=self)
+                total_size=total_size, parent=self)
             self.model_download_progress_dialog.canceled.connect(
                 self.on_cancel_model_progress_dialog)
@@ -503,10 +516,12 @@ class FileTranscriberWidget(QWidget):
         self.model_download_progress_dialog.setValue(
             current_size=current_size)
 
+    def on_download_model_error(self, error: str):
+        show_model_download_error_dialog(self, error)
+        self.reset_transcription()
+
     def on_transcriber_event(self, event: FileTranscriber.Event):
-        if isinstance(event, FileTranscriber.LoadedModelEvent):
-            self.reset_model_download()
-        elif isinstance(event, FileTranscriber.ProgressEvent):
+        if isinstance(event, FileTranscriber.ProgressEvent):
             current_size = event.current_value
             total_size = event.max_value
 
@@ -535,8 +550,8 @@ class FileTranscriberWidget(QWidget):
             self.transcriber_progress_dialog = None
 
     def on_cancel_model_progress_dialog(self):
-        if self.file_transcriber is not None:
-            self.file_transcriber.stop_loading_model()
+        if self.model_loader is not None:
+            self.model_loader.stop()
         self.reset_model_download()
 
     def reset_model_download(self):
@@ -583,6 +598,7 @@ class RecordingTranscriberWidget(QWidget):
     model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
     settings: Settings
     transcriber: Optional[RecordingTranscriberObject] = None
+    model_loader: Optional[ModelLoader] = None
 
     def __init__(self, parent: Optional[QWidget]) -> None:
         super().__init__(parent)
@@ -637,6 +653,8 @@ class RecordingTranscriberWidget(QWidget):
 
         self.setLayout(layout)
 
+        self.pool = QThreadPool()
+
     def on_device_changed(self, device_id: int):
         self.selected_device_id = device_id
 
@@ -663,32 +681,36 @@ class RecordingTranscriberWidget(QWidget):
 
         model_name = get_model_name(self.selected_quality)
 
-        self.transcriber = RecordingTranscriberObject(
-            model_name=model_name, use_whisper_cpp=use_whisper_cpp,
-            language=self.selected_language, task=self.selected_task,
-            input_device_index=self.selected_device_id,
-            parent=self
-        )
-        self.transcriber.event_changed.connect(
-            self.on_transcriber_event_changed)
-        self.transcriber.download_model_progress.connect(
-            self.on_download_model_progress)
-
-        self.transcriber.start_recording()
-
-    def on_transcriber_event_changed(self, event: RecordingTranscriber.Event):
-        if isinstance(event, RecordingTranscriber.LoadedModelEvent):
-            # Clear text box placeholder because the first chunk takes a while to process
-            self.text_box.setPlaceholderText('')
-            self.timer_label.start_timer()
-            self.record_button.setDisabled(False)
-            self.reset_model_download()
-        elif isinstance(event, RecordingTranscriber.TranscribedNextChunkEvent):
-            text = event.text.strip()
-            if len(text) > 0:
-                self.text_box.moveCursor(QTextCursor.MoveOperation.End)
-                self.text_box.insertPlainText(text + '\n\n')
-                self.text_box.moveCursor(QTextCursor.MoveOperation.End)
+        def start_recording_transcription(model_path: str):
+            # Clear text box placeholder because the first chunk takes a while to process
+            self.text_box.setPlaceholderText('')
+            self.timer_label.start_timer()
+            self.record_button.setDisabled(False)
+            if self.model_download_progress_dialog is not None:
+                self.model_download_progress_dialog = None
+
+            self.transcriber = RecordingTranscriberObject(
+                model_path=model_path, use_whisper_cpp=use_whisper_cpp,
+                language=self.selected_language, task=self.selected_task,
+                input_device_index=self.selected_device_id,
+                parent=self
+            )
+            self.transcriber.event_changed.connect(
+                self.on_transcriber_event_changed)
+            self.transcriber.download_model_progress.connect(
+                self.on_download_model_progress)
+
+            self.transcriber.start_recording()
+
+        self.model_loader = ModelLoader(
+            name=model_name, use_whisper_cpp=use_whisper_cpp)
+        self.model_loader.signals.progress.connect(
+            self.on_download_model_progress)
+        self.model_loader.signals.completed.connect(
+            start_recording_transcription)
+        self.model_loader.signals.error.connect(self.on_download_model_error)
+
+        self.pool.start(self.model_loader)
 
     def on_download_model_progress(self, progress: Tuple[int, int]):
         (current_size, _) = progress
@@ -703,14 +725,26 @@ class RecordingTranscriberWidget(QWidget):
         self.model_download_progress_dialog.setValue(
             current_size=current_size)
 
+    def on_download_model_error(self, error: str):
+        show_model_download_error_dialog(self, error)
+        self.stop_recording()
+
+    def on_transcriber_event_changed(self, event: RecordingTranscriber.Event):
+        if isinstance(event, RecordingTranscriber.TranscribedNextChunkEvent):
+            text = event.text.strip()
+            if len(text) > 0:
+                self.text_box.moveCursor(QTextCursor.MoveOperation.End)
+                self.text_box.insertPlainText(text + '\n\n')
+                self.text_box.moveCursor(QTextCursor.MoveOperation.End)
+
     def stop_recording(self):
         if self.transcriber is not None:
             self.transcriber.stop_recording()
         self.timer_label.stop_timer()
 
     def on_cancel_model_progress_dialog(self):
-        if self.transcriber is not None:
-            self.transcriber.stop_loading_model()
+        if self.model_loader is not None:
+            self.model_loader.stop()
         self.reset_model_download()
         self.record_button.force_stop()
         self.record_button.setDisabled(False)
@@ -736,7 +770,7 @@ class AppIcon(QIcon):
 
 
 class AboutDialog(QDialog):
-    def __init__(self, parent: Optional[QWidget]) -> None:
+    def __init__(self, parent: Optional[QWidget] = None) -> None:
         super().__init__(parent)
 
         self.setFixedSize(200, 200)
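
Both widgets now share the same wiring: construct a ModelLoader, connect its signals, and hand it to a QThreadPool; the completed signal delivers the resolved model path into a closure that builds and starts the transcriber. A minimal sketch of that pattern outside the GUI (the model name and the print handlers are stand-ins, not code from this commit):

    from PyQt6.QtCore import QThreadPool

    from model_loader import ModelLoader

    pool = QThreadPool()
    loader = ModelLoader(name='tiny', use_whisper_cpp=True)

    # progress emits (bytes_received, bytes_total) while the model downloads
    loader.signals.progress.connect(lambda p: print(f'{p[0]}/{p[1]} bytes'))
    # completed emits the local file path once the checksum verifies
    loader.signals.completed.connect(lambda path: print('model at', path))
    # error emits a user-displayable message string
    loader.signals.error.connect(lambda msg: print('failed:', msg))

    pool.start(loader)   # runs ModelLoader.run() on a pooled thread
    pool.waitForDone()   # block until done; the GUI instead stays responsive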
gui_test.py

@@ -7,7 +7,7 @@ from PyQt6.QtCore import Qt
 
 from gui import (Application, AudioDevicesComboBox,
                  DownloadModelProgressDialog, FileTranscriberWidget,
-                 LanguagesComboBox, MainWindow, OutputFormatsComboBox, Quality, Settings,
+                 LanguagesComboBox, MainWindow, OutputFormatsComboBox, Quality, Settings, AboutDialog,
                  QualityComboBox, TranscriberProgressDialog)
 from transcriber import OutputFormat
 
@@ -204,3 +204,9 @@ class TestSettings:
 
         settings.set_enable_ggml_inference(False)
         assert settings.get_enable_ggml_inference() is False
+
+
+class TestAboutDialog:
+    def test_should_create(self):
+        dialog = AboutDialog()
+        assert dialog is not None
model_loader.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+import hashlib
+import logging
+import os
+import warnings
+
+import requests
+import whisper
+from platformdirs import user_cache_dir
+from PyQt6.QtCore import QObject, QRunnable, pyqtSignal, pyqtSlot
+
+MODELS_SHA256 = {
+    'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
+    'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
+    'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
+}
+
+
+class ModelLoader(QRunnable):
+    class Signals(QObject):
+        progress = pyqtSignal(tuple)  # (current, total)
+        completed = pyqtSignal(str)
+        error = pyqtSignal(str)
+
+    signals: Signals
+    stopped = False
+
+    def __init__(self, name: str, use_whisper_cpp=False) -> None:
+        super(ModelLoader, self).__init__()
+        self.name = name
+        self.use_whisper_cpp = use_whisper_cpp
+        self.signals = self.Signals()
+
+    @pyqtSlot()
+    def run(self):
+        try:
+            if self.use_whisper_cpp:
+                root = user_cache_dir('Buzz')
+                url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{self.name}.bin'
+            else:
+                root = os.getenv(
+                    "XDG_CACHE_HOME",
+                    os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+                )
+                url = whisper._MODELS[self.name]
+
+            os.makedirs(root, exist_ok=True)
+
+            model_path = os.path.join(root, os.path.basename(url))
+
+            if os.path.exists(model_path) and not os.path.isfile(model_path):
+                raise RuntimeError(
+                    f"{model_path} exists and is not a regular file")
+
+            expected_sha256 = MODELS_SHA256[self.name] if self.use_whisper_cpp else url.split(
+                "/")[-2]
+            if os.path.isfile(model_path):
+                model_bytes = open(model_path, "rb").read()
+                if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+                    self.signals.completed.emit(model_path)
+                    return
+                else:
+                    warnings.warn(
+                        f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+            # Downloads the model using the requests module instead of urllib to
+            # use the certs from certifi when the app is running in frozen mode
+            with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
+                source.raise_for_status()
+                total_size = int(source.headers.get('Content-Length', 0))
+                current = 0
+                self.signals.progress.emit((0, total_size))
+                for chunk in source.iter_content(chunk_size=8192):
+                    if self.stopped:
+                        return
+                    output.write(chunk)
+                    current += len(chunk)
+                    self.signals.progress.emit((current, total_size))
+
+            model_bytes = open(model_path, "rb").read()
+            if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+                raise RuntimeError(
+                    "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
+
+            self.signals.completed.emit(model_path)
+        except RuntimeError as exc:
+            self.signals.error.emit(str(exc))
+            logging.exception('')
+        except requests.RequestException:
+            self.signals.error.emit('A connection error occurred.')
+            logging.exception('')
+        except Exception:
+            self.signals.error.emit('An unknown error occurred.')
+            logging.exception('')
+
+    def stop(self):
+        self.stopped = True
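
Note where the expected checksum comes from: for whisper.cpp models it is the hard-coded MODELS_SHA256 table, while for the standard Whisper models it is embedded in the download URL itself — openai-whisper hosts each model under a path segment that is its SHA-256 digest, which is what url.split("/")[-2] extracts. The verification above reads the whole file into memory before hashing; a chunked variant, sketched here as a hypothetical helper rather than part of the commit, would bound memory use:

    import hashlib

    # Same check as the loader performs, but hashing in 8 KiB chunks so a
    # multi-gigabyte model file never has to fit in memory at once.
    def sha256_of(path: str, chunk_size: int = 8192) -> str:
        digest = hashlib.sha256()
        with open(path, 'rb') as file:
            for chunk in iter(lambda: file.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()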
transcriber.py

@@ -19,8 +19,7 @@ import whisper
 from sounddevice import PortAudioError
 
 from conn import pipe_stderr, pipe_stdout
-from whispr import (ModelLoader, Segment, Stopped, Task, WhisperCpp,
-                    read_progress, whisper_cpp_params)
+from whispr import Segment, Task, WhisperCpp, read_progress, whisper_cpp_params
 
 
 class RecordingTranscriber:
@@ -34,20 +33,19 @@ class RecordingTranscriber:
     class Event:
         pass
 
-    class LoadedModelEvent(Event):
-        pass
-
     @dataclass
     class TranscribedNextChunkEvent(Event):
         text: str
 
     def __init__(self,
-                 model_name: str, use_whisper_cpp: bool,
+                 model_path: str, use_whisper_cpp: bool,
                  language: Optional[str], task: Task,
                  on_download_model_chunk: Callable[[
                      int, int], None] = lambda *_: None,
                  event_callback: Callable[[Event], None] = lambda *_: None,
                  input_device_index: Optional[int] = None) -> None:
+        self.model_path = model_path
         self.use_whisper_cpp = use_whisper_cpp
         self.current_stream = None
         self.event_callback = event_callback
         self.language = language
@@ -62,24 +60,18 @@ class RecordingTranscriber:
         self.mutex = threading.Lock()
         self.text = ''
         self.on_download_model_chunk = on_download_model_chunk
-        self.model_loader = ModelLoader(
-            name=model_name, use_whisper_cpp=use_whisper_cpp,)
 
     def start_recording(self):
         self.current_thread = Thread(target=self.process_queue)
         self.current_thread.start()
 
     def process_queue(self):
-        try:
-            model = self.model_loader.load(
-                on_download_model_chunk=self.on_download_model_chunk)
-        except Stopped:
-            return
-
-        self.event_callback(self.LoadedModelEvent())
+        model = WhisperCpp(
+            self.model_path) if self.use_whisper_cpp else whisper.load_model(self.model_path)
 
-        logging.debug('Recording, language = %s, task = %s, device = %s, sample rate = %s',
-                      self.language, self.task, self.input_device_index, self.sample_rate)
+        logging.debug(
+            'Recording, language = %s, task = %s, device = %s, sample rate = %s, model_path = %s',
+            self.language, self.task, self.input_device_index, self.sample_rate, self.model_path)
         self.current_stream = sounddevice.InputStream(
             samplerate=self.sample_rate,
             blocksize=1 * self.sample_rate,  # 1 sec
@@ -160,9 +152,6 @@ class RecordingTranscriber:
             self.current_thread.join()
             logging.debug('Recording thread terminated')
 
-    def stop_loading_model(self):
-        self.model_loader.stop()
-
 
 class OutputFormat(enum.Enum):
     TXT = 'txt'
@@ -226,15 +215,12 @@ class FileTranscriber:
         current_value: int
         max_value: int
 
-    class LoadedModelEvent(Event):
-        pass
-
     class CompletedTranscriptionEvent(Event):
         pass
 
     def __init__(
             self,
-            model_name: str, use_whisper_cpp: bool,
+            model_path: str, use_whisper_cpp: bool,
             language: Optional[str], task: Task, file_path: str,
             output_file_path: str, output_format: OutputFormat,
             word_level_timings: bool,
@@ -249,12 +235,9 @@ class FileTranscriber:
         self.open_file_on_complete = open_file_on_complete
         self.output_format = output_format
         self.word_level_timings = word_level_timings
 
-        self.model_name = model_name
+        self.model_path = model_path
         self.use_whisper_cpp = use_whisper_cpp
         self.on_download_model_chunk = on_download_model_chunk
 
-        self.model_loader = ModelLoader(self.model_name, self.use_whisper_cpp)
         self.event_callback = event_callback
 
     def start(self):
@@ -262,21 +245,10 @@ class FileTranscriber:
         self.current_thread.start()
 
     def transcribe(self):
-        if self.stopped:
-            return
-
-        try:
-            model_path = self.model_loader.get_model_path(
-                on_download_model_chunk=self.on_download_model_chunk)
-        except Stopped:
-            return
-
-        self.event_callback(self.LoadedModelEvent())
-
         time_started = datetime.datetime.now()
         logging.debug(
             'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s',
-            self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path)
+            self.file_path, self.language, self.task, self.output_file_path, self.output_format, self.model_path)
 
         recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
 
@@ -288,7 +260,7 @@ class FileTranscriber:
             self.current_process = multiprocessing.Process(
                 target=transcribe_whisper_cpp,
                 args=(
-                    send_pipe, model_path, self.file_path,
+                    send_pipe, self.model_path, self.file_path,
                     self.output_file_path, self.open_file_on_complete,
                     self.output_format,
                     self.language if self.language is not None else 'en',
@@ -299,7 +271,7 @@ class FileTranscriber:
             self.current_process = multiprocessing.Process(
                 target=transcribe_whisper,
                 args=(
-                    send_pipe, model_path, self.file_path,
+                    send_pipe, self.model_path, self.file_path,
                     self.language, self.task, self.output_file_path,
                     self.open_file_on_complete, self.output_format,
                     self.word_level_timings
@@ -324,9 +296,6 @@ class FileTranscriber:
         logging.debug('Completed file transcription, time taken = %s',
                       datetime.datetime.now()-time_started)
 
-    def stop_loading_model(self):
-        self.model_loader.stop()
-
     def join(self):
         if self.current_thread is not None:
             self.current_thread.join()
@@ -335,8 +304,6 @@ class FileTranscriber:
         if self.stopped is False:
             self.stopped = True
 
-            self.model_loader.stop()
-
             if self.current_process is not None and self.current_process.is_alive():
                 self.current_process.terminate()
                 logging.debug('File transcription process terminated')
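
With these changes the transcribers no longer download anything: RecordingTranscriber.process_queue loads the model straight from the path it was given (WhisperCpp for ggml models, whisper.load_model otherwise), and FileTranscriber passes self.model_path into its worker process. Callers resolve the path first, normally via ModelLoader. A sketch of driving FileTranscriber directly once a path is known (the file names here are illustrative, not from this commit):

    from transcriber import FileTranscriber, OutputFormat
    from whispr import Task

    transcriber = FileTranscriber(
        model_path='/path/to/ggml-model-whisper-tiny.bin', use_whisper_cpp=True,
        language='en', task=Task.TRANSCRIBE, file_path='audio.mp3',
        output_file_path='audio.txt', output_format=OutputFormat.TXT,
        word_level_timings=False, open_file_on_complete=False)
    transcriber.start()
    transcriber.join()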
transcriber_test.py

@@ -1,19 +1,35 @@
 import logging
 import os
 import pathlib
 import tempfile
 import time
 
 import pytest
 
+from model_loader import ModelLoader
 from transcriber import (FileTranscriber, OutputFormat, RecordingTranscriber,
                          to_timestamp)
 from whispr import Task
 
 
+def get_model_path(model_name: str, use_whisper_cpp: bool) -> str:
+    model_loader = ModelLoader(model_name, use_whisper_cpp)
+    model_path = ''
+
+    def on_load_model(path: str):
+        nonlocal model_path
+        model_path = path
+
+    model_loader.signals.completed.connect(on_load_model)
+    model_loader.run()
+    return model_path
+
+
 class TestRecordingTranscriber:
     def test_transcriber(self):
+        model_path = get_model_path('tiny', True)
         transcriber = RecordingTranscriber(
-            model_name='tiny', use_whisper_cpp=True, language='en',
+            model_path=model_path, use_whisper_cpp=True, language='en',
             task=Task.TRANSCRIBE)
         assert transcriber is not None
@@ -36,7 +52,8 @@ class TestFileTranscriber:
         (False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle, un podcast'),
         (False, OutputFormat.SRT, '1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
         (False, OutputFormat.VTT, 'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
-        (True, OutputFormat.SRT, '1\n00:00:00.040 --> 00:00:00.359\n Bienvenue dans\n\n2\n00:00:00.359 --> 00:00:00.419\n Passe-'),
+        (True, OutputFormat.SRT,
+         '1\n00:00:00.040 --> 00:00:00.359\n Bienvenue dans\n\n2\n00:00:00.359 --> 00:00:00.419\n Passe-'),
     ])
     def test_transcribe_whisper(self, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str):
         output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
@@ -46,8 +63,9 @@ class TestFileTranscriber:
         def event_callback(event: FileTranscriber.Event):
             events.append(event)
 
+        model_path = get_model_path('tiny', False)
         transcriber = FileTranscriber(
-            model_name='tiny', use_whisper_cpp=False, language='fr',
+            model_path=model_path, use_whisper_cpp=False, language='fr',
             task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
             output_file_path=output_file_path.as_posix(), output_format=output_format,
             open_file_on_complete=False, event_callback=event_callback,
@@ -78,19 +96,19 @@ class TestFileTranscriber:
         def event_callback(event: FileTranscriber.Event):
             events.append(event)
 
+        model_path = get_model_path('tiny', False)
         transcriber = FileTranscriber(
-            model_name='tiny', use_whisper_cpp=False, language='fr',
+            model_path=model_path, use_whisper_cpp=False, language='fr',
             task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
             output_file_path=output_file_path, output_format=OutputFormat.TXT,
             open_file_on_complete=False, event_callback=event_callback,
             word_level_timings=False)
         transcriber.start()
         time.sleep(1)
         transcriber.stop()
 
-        # Assert that file was not created and there was no completed progress event
+        # Assert that file was not created
         assert os.path.isfile(output_file_path) is False
-        assert any([isinstance(event, FileTranscriber.ProgressEvent)
-                    and event.current_value == event.max_value for event in events]) is False
 
     def test_transcribe_whisper_cpp(self):
         output_file_path = os.path.join(
@@ -103,8 +121,9 @@ class TestFileTranscriber:
         def event_callback(event: FileTranscriber.Event):
             events.append(event)
 
+        model_path = get_model_path('tiny', True)
         transcriber = FileTranscriber(
-            model_name='tiny', use_whisper_cpp=True, language='fr',
+            model_path=model_path, use_whisper_cpp=True, language='fr',
             task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
             output_file_path=output_file_path, output_format=OutputFormat.TXT,
             open_file_on_complete=False, event_callback=event_callback,
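
The get_model_path helper is synchronous by construction: it invokes model_loader.run() on the current thread instead of submitting it to a QThreadPool, and with the emitter and the connected callable on the same thread PyQt delivers the completed signal immediately, so model_path is assigned before run() returns.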
whispr.py (130 changed lines)
@@ -1,24 +1,13 @@
 import ctypes
 import enum
-import hashlib
 import logging
-import multiprocessing
 import os
-import warnings
 from dataclasses import dataclass
-from multiprocessing.connection import Connection
-from queue import Empty
-from threading import Thread
 from typing import Any, Callable, List, Union
 
 import numpy as np
-import requests
 import whisper
-from appdirs import user_cache_dir
-from tqdm import tqdm
 from whisper import Whisper
 
 from conn import pipe_stderr
 
 # Catch exception from whisper.dll not getting loaded.
 # TODO: Remove flag and try-except when issue with loading
@@ -100,125 +89,6 @@ class WhisperCpp:
         whisper_cpp.whisper_free((self.ctx))
 
 
-class ModelLoader:
-    process: multiprocessing.Process
-    model_path_queue: multiprocessing.Queue
-
-    def __init__(self, name: str, use_whisper_cpp=False) -> None:
-        self.name = name
-        self.use_whisper_cpp = use_whisper_cpp
-
-        self.recv_pipe, self.send_pipe = multiprocessing.Pipe(duplex=False)
-        self.model_path_queue = multiprocessing.Queue()
-        self.process = multiprocessing.Process(
-            target=self.load_whisper_cpp_model if self.use_whisper_cpp else self.load_whisper_model,
-            args=(self.send_pipe, self.model_path_queue, self.name))
-
-    def get_model_path(self, on_download_model_chunk: Callable[[int, int], None] = lambda *_: None) -> str:
-        logging.debug(
-            'Loading model = %s, whisper.cpp = %s', self.name, self.use_whisper_cpp)
-
-        # Fixes an issue with the pickling of a torch model from another process
-        os.environ["no_proxy"] = '*'
-
-        on_download_model_chunk(0, 100)
-
-        self.process.start()
-
-        thread = Thread(target=read_progress, args=(
-            self.recv_pipe, self.use_whisper_cpp, on_download_model_chunk))
-        thread.start()
-
-        self.process.join()
-
-        self.recv_pipe.close()
-        self.send_pipe.close()
-
-        on_download_model_chunk(100, 100)
-        try:
-            model_path = self.model_path_queue.get(block=False)
-            logging.debug('Model path = %s', model_path)
-            return model_path
-        except Empty as exc:
-            raise Stopped from exc
-
-    def load(self, on_download_model_chunk: Callable[[int, int], None] = lambda *_: None) -> Union[Whisper, WhisperCpp]:
-        model_path = self.get_model_path(on_download_model_chunk)
-        return WhisperCpp(model_path) if self.use_whisper_cpp else whisper.load_model(model_path)
-
-    def load_whisper_cpp_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
-        path = download_model(name, use_whisper_cpp=True)
-        queue.put(path)
-
-    def load_whisper_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
-        with pipe_stderr(stderr_conn):
-            path = download_model(name, use_whisper_cpp=False)
-            queue.put(path)
-
-    def stop(self):
-        if self.process.is_alive():
-            self.process.terminate()
-
-    def is_alive(self):
-        return self.process.is_alive()
-
-
-MODELS_SHA256 = {
-    'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
-    'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
-    'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
-}
-
-
-def download_model(name: str, use_whisper_cpp=False):
-    if use_whisper_cpp:
-        root = user_cache_dir('Buzz')
-        url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{name}.bin'
-    else:
-        root = os.getenv(
-            "XDG_CACHE_HOME",
-            os.path.join(os.path.expanduser("~"), ".cache", "whisper")
-        )
-        url = whisper._MODELS[name]
-
-    os.makedirs(root, exist_ok=True)
-
-    model_path = os.path.join(root, os.path.basename(url))
-
-    if os.path.exists(model_path) and not os.path.isfile(model_path):
-        raise RuntimeError(
-            f"{model_path} exists and is not a regular file")
-
-    expected_sha256 = MODELS_SHA256[name] if use_whisper_cpp else url.split(
-        "/")[-2]
-    if os.path.isfile(model_path):
-        model_bytes = open(model_path, "rb").read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_path
-        else:
-            warnings.warn(
-                f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
-
-    # Downloads the model using the requests module instead of urllib to
-    # use the certs from certifi when the app is running in frozen mode
-    with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
-        source.raise_for_status()
-        total_size = int(source.headers.get('Content-Length', 0))
-        with tqdm(total=total_size, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
-            for chunk in source.iter_content(chunk_size=8192):
-                output.write(chunk)
-                loop.update(len(chunk))
-
-    model_bytes = open(model_path, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
-
-    return model_path
-
-
 # tqdm progress line looks like: " 54%|█████     |"
 def tqdm_progress(line: str):
     percent_progress = line.split('|')[0].strip().strip('%')
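
Net effect on whispr.py: the module now holds only the whisper.cpp ctypes bindings and related helpers. Downloading, checksum verification, and progress reporting moved to model_loader.ModelLoader, and cancellation changed from terminating a child process (signalled by the Stopped exception when the model path queue came back empty) to a cooperative stopped flag checked between download chunks.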