Add stable timings (#159)

This commit is contained in:
Chidi Williams 2022-11-12 16:31:08 +00:00 committed by GitHub
parent a1b9097133
commit 392f9cb469
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 129 additions and 86 deletions

11
.coveragerc Normal file
View file

@ -0,0 +1,11 @@
[run]
omit =
whisper_cpp.py
*_test.py
stable_ts/*
[html]
directory = coverage/html
[report]
fail_under = 75

1
.gitignore vendored
View file

@ -3,6 +3,7 @@ __pycache__/
build/
.pytest_cache/
.coverage*
!.coveragerc
.env
htmlcov/
libwhisper.*

4
.gitmodules vendored
View file

@ -1,3 +1,7 @@
[submodule "whisper.cpp"]
path = whisper.cpp
url = https://github.com/chidiwilliams/whisper.cpp
[submodule "stable_ts"]
path = stable_ts
url = https://github.com/chidiwilliams/stable-ts
branch = main

View file

@ -43,7 +43,7 @@ clean:
rm -rf dist/* || true
test: whisper_cpp.py
pytest --cov --cov-fail-under=69 --cov-report html
pytest --cov
dist/Buzz: whisper_cpp.py
pyinstaller --noconfirm Buzz.spec

33
gui.py
View file

@ -12,10 +12,10 @@ from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QTimer,
QUrl, pyqtSignal)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor)
from PyQt6.QtWidgets import (QApplication, QComboBox, QDialog, QFileDialog,
QGridLayout, QLabel, QMainWindow, QMessageBox,
QPlainTextEdit, QProgressDialog, QPushButton,
QVBoxLayout, QWidget)
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QFileDialog, QGridLayout, QLabel, QMainWindow,
QMessageBox, QPlainTextEdit, QProgressDialog,
QPushButton, QVBoxLayout, QWidget)
from requests import get
from whisper import tokenizer
@ -270,16 +270,18 @@ class FileTranscriberObject(QObject):
transcriber: FileTranscriber
def __init__(
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
task: Task, file_path: str, output_file_path: str,
output_format: OutputFormat, parent: Optional['QObject'], *args) -> None:
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
task: Task, file_path: str, output_file_path: str,
output_format: OutputFormat, word_level_timings: bool,
parent: Optional['QObject'], *args) -> None:
super().__init__(parent, *args)
self.transcriber = FileTranscriber(
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
on_download_model_chunk=self.on_download_model_progress,
language=language, task=task, file_path=file_path,
output_file_path=output_file_path, output_format=output_format,
event_callback=self.on_file_transcriber_event)
event_callback=self.on_file_transcriber_event,
word_level_timings=word_level_timings)
def on_download_model_progress(self, current: int, total: int):
self.download_model_progress.emit((current, total))
@ -380,6 +382,7 @@ class FileTranscriberWidget(QWidget):
selected_language: Optional[str] = None
selected_task = Task.TRANSCRIBE
selected_output_format = OutputFormat.TXT
enabled_word_level_timings = False
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
file_transcriber: Optional[FileTranscriberObject] = None
@ -418,6 +421,11 @@ class FileTranscriberWidget(QWidget):
output_formats_combo_box.output_format_changed.connect(
self.on_output_format_changed)
self.word_level_timings_checkbox = QCheckBox('Word-level timings')
self.word_level_timings_checkbox.stateChanged.connect(
self.on_word_level_timings_changed)
self.word_level_timings_checkbox.setDisabled(True)
grid = (
((0, 5, FormLabel('Task:', parent=self)), (5, 7, self.tasks_combo_box)),
((0, 5, FormLabel('Language:', parent=self)),
@ -426,6 +434,7 @@ class FileTranscriberWidget(QWidget):
(5, 7, self.quality_combo_box)),
((0, 5, FormLabel('Export As:', self)),
(5, 7, output_formats_combo_box)),
((5, 7, self.word_level_timings_checkbox),),
((9, 3, self.run_button),)
)
@ -447,6 +456,8 @@ class FileTranscriberWidget(QWidget):
def on_output_format_changed(self, output_format: OutputFormat):
self.selected_output_format = output_format
self.word_level_timings_checkbox.setDisabled(
output_format == OutputFormat.TXT)
def on_click_run(self):
default_path = FileTranscriber.get_default_output_file_path(
@ -469,6 +480,7 @@ class FileTranscriberWidget(QWidget):
file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
parent=self)
self.file_transcriber.download_model_progress.connect(
self.on_download_model_progress)
@ -530,6 +542,9 @@ class FileTranscriberWidget(QWidget):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog = None
def on_word_level_timings_changed(self, value: int):
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value
class Settings(QSettings):
ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
@ -829,7 +844,7 @@ class FileTranscriberMainWindow(MainWindow):
def __init__(self, file_path: str, parent: Optional[QWidget], *args) -> None:
super().__init__(title=get_short_file_path(
file_path), w=400, h=210, parent=parent, *args)
file_path), w=400, h=240, parent=parent, *args)
self.central_widget = FileTranscriberWidget(file_path, self)
self.central_widget.setContentsMargins(10, 10, 10, 10)

18
poetry.lock generated
View file

@ -1176,6 +1176,9 @@ PyQt6-sip = [
{file = "PyQt6_sip-13.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3ac7e0800180202dcc0c7035ff88c2a6f4a0f5acb20c4a19f71d807d0f7857b7"},
{file = "PyQt6_sip-13.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bb4f2e2fdcf3a8dafe4256750bbedd9e7107c4fd8afa9c25be28423c36bb12b8"},
{file = "PyQt6_sip-13.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:de601187055d684b36ebe6e800a5deacaa55b69d71ad43312b76422cfeae0e12"},
{file = "PyQt6_sip-13.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e3b17308ca729bcb6d25c01144c6b2e17d40812231c3ef9caaa72a78db2b1069"},
{file = "PyQt6_sip-13.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d51704d50b82713fd7c928b7deb31e17be239ddac74fc2fd708e52bd21ecea3a"},
{file = "PyQt6_sip-13.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:77af9c7e3f50414ec5af9b1534aaf2ba25115ae65aa5ed735111c8ef0884b862"},
{file = "PyQt6_sip-13.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:83b446d247a92d119d507dbc94fc1f47389d8118a5b6232a2859951157319a30"},
{file = "PyQt6_sip-13.4.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:802b0cfed19900183220c46895c2635f0dd062f2d275a25506423f911ef74db4"},
{file = "PyQt6_sip-13.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2694ae67811cefb6ea3ee0e9995755b45e4952f4dcadec8c04300fd828f91c75"},
@ -1365,21 +1368,36 @@ sounddevice = [
tokenizers = [
{file = "tokenizers-0.13.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:a6f36b1b499233bb4443b5e57e20630c5e02fba61109632f5e00dab970440157"},
{file = "tokenizers-0.13.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:bc6983282ee74d638a4b6d149e5dadd8bc7ff1d0d6de663d69f099e0c6bddbeb"},
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16756e6ab264b162f99c0c0a8d3d521328f428b33374c5ee161c0ebec42bf3c0"},
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b10db6e4b036c78212c6763cb56411566edcf2668c910baa1939afd50095ce48"},
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:238e879d1a0f4fddc2ce5b2d00f219125df08f8532e5f1f2ba9ad42f02b7da59"},
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0901a5c6538d2d2dc752c6b4bde7dab170fddce559ec75662cfad03b3187c8f6"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ba9baa76b5a3eefa78b6cc351315a216232fd727ee5e3ce0f7c6885d9fb531b"},
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3bc9f7d7f4c1aa84bb6b8d642a60272c8a2c987669e9bb0ac26daf0c6a9fc8"},
{file = "tokenizers-0.13.2-cp38-cp38-win32.whl", hash = "sha256:efbf189fb9cf29bd29e98c0437bdb9809f9de686a1e6c10e0b954410e9ca2142"},
{file = "tokenizers-0.13.2-cp38-cp38-win_amd64.whl", hash = "sha256:0b4cb2c60c094f31ea652f6cf9f349aae815f9243b860610c29a69ed0d7a88f8"},
{file = "tokenizers-0.13.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:b47d6212e7dd05784d7330b3b1e5a170809fa30e2b333ca5c93fba1463dec2b7"},
{file = "tokenizers-0.13.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:80a57501b61ec4f94fb7ce109e2b4a1a090352618efde87253b4ede6d458b605"},
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61507a9953f6e7dc3c972cbc57ba94c80c8f7f686fbc0876afe70ea2b8cc8b04"},
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c09f4fa620e879debdd1ec299bb81e3c961fd8f64f0e460e64df0818d29d845c"},
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:66c892d85385b202893ac6bc47b13390909e205280e5df89a41086cfec76fedb"},
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e306b0941ad35087ae7083919a5c410a6b672be0343609d79a1171a364ce79"},
{file = "tokenizers-0.13.2-cp39-cp39-win32.whl", hash = "sha256:79189e7f706c74dbc6b753450757af172240916d6a12ed4526af5cc6d3ceca26"},
{file = "tokenizers-0.13.2-cp39-cp39-win_amd64.whl", hash = "sha256:486d637b413fddada845a10a45c74825d73d3725da42ccd8796ccd7a1c07a024"},

View file

@ -9,7 +9,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.9.13,<3.11"
sounddevice = "^0.4.5"
whisper = {git = "https://github.com/openai/whisper.git"}
whisper = { git = "https://github.com/openai/whisper.git" }
torch = "1.12.1"
numpy = "^1.23.3"
transformers = "^4.22.1"

1
stable_ts Submodule

@ -0,0 +1 @@
Subproject commit 5c1966392b0a67b79a454d3f902772a2d9085d77

View file

@ -18,6 +18,7 @@ import whisper
from sounddevice import PortAudioError
from conn import pipe_stderr, pipe_stdout
from stable_ts.stable_whisper import group_word_timestamps, modify_model
from whispr import (ModelLoader, Segment, Stopped, Task, WhisperCpp,
read_progress, whisper_cpp_params)
@ -108,7 +109,7 @@ class RecordingTranscriber:
audio=samples,
params=whisper_cpp_params(
language=self.language if self.language is not None else 'en',
task=self.task.value))
task=self.task.value, word_level_timings=False))
next_text: str = result.get('text')
@ -236,6 +237,7 @@ class FileTranscriber:
model_name: str, use_whisper_cpp: bool,
language: Optional[str], task: Task, file_path: str,
output_file_path: str, output_format: OutputFormat,
word_level_timings: bool,
event_callback: Callable[[Event], None] = lambda *_: None,
on_download_model_chunk: Callable[[
int, int], None] = lambda *_: None,
@ -246,6 +248,7 @@ class FileTranscriber:
self.task = task
self.open_file_on_complete = open_file_on_complete
self.output_format = output_format
self.word_level_timings = word_level_timings
self.model_name = model_name
self.use_whisper_cpp = use_whisper_cpp
@ -290,6 +293,7 @@ class FileTranscriber:
self.output_format,
self.language if self.language is not None else 'en',
self.task, True, True,
self.word_level_timings
))
else:
self.current_process = multiprocessing.Process(
@ -298,6 +302,7 @@ class FileTranscriber:
send_pipe, model_path, self.file_path,
self.language, self.task, self.output_file_path,
self.open_file_on_complete, self.output_format,
self.word_level_timings
))
self.current_process.start()
@ -352,18 +357,26 @@ class FileTranscriber:
def transcribe_whisper(
stderr_conn: Connection, model_path: str, file_path: str,
language: Optional[str], task: Task, output_file_path: str,
open_file_on_complete: bool, output_format: OutputFormat):
open_file_on_complete: bool, output_format: OutputFormat,
word_level_timings: bool):
with pipe_stderr(stderr_conn):
model = whisper.load_model(model_path)
result = whisper.transcribe(
model=model, audio=file_path, language=language, task=task.value, verbose=False)
if word_level_timings:
modify_model(model)
result = model.transcribe(
audio=file_path, language=language, task=task.value, verbose=False)
whisper_segments = group_word_timestamps(
result) if word_level_timings else result.get('segments')
segments = map(
lambda segment: Segment(
start=segment.get('start')*1000, # s to ms
end=segment.get('end')*1000, # s to ms
text=segment.get('text')),
result.get('segments'))
whisper_segments)
write_output(output_file_path, list(
segments), open_file_on_complete, output_format)
@ -372,13 +385,14 @@ def transcribe_whisper(
def transcribe_whisper_cpp(
stderr_conn: Connection, model_path: str, audio: typing.Union[np.ndarray, str],
output_file_path: str, open_file_on_complete: bool, output_format: OutputFormat,
language: str, task: Task, print_realtime: bool, print_progress: bool):
language: str, task: Task, print_realtime: bool, print_progress: bool,
word_level_timings: bool):
# TODO: capturing output does not work because ctypes functions
# See: https://stackoverflow.com/questions/9488560/capturing-print-output-from-shared-library-called-from-python-with-ctypes-module
with pipe_stdout(stderr_conn), pipe_stderr(stderr_conn):
model = WhisperCpp(model_path)
params = whisper_cpp_params(
language, task, print_realtime, print_progress)
language, task, word_level_timings, print_realtime, print_progress)
result = model.transcribe(audio=audio, params=params)
segments: List[Segment] = result.get('segments')
write_output(

View file

@ -1,4 +1,5 @@
import os
import pathlib
import tempfile
import pytest
@ -29,10 +30,16 @@ class TestFileTranscriber:
assert srt.startswith('/a/b/c (Translated on ')
assert srt.endswith('.srt')
def test_transcribe_whisper(self):
output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt')
if os.path.exists(output_file_path):
os.remove(output_file_path)
@pytest.mark.parametrize(
'word_level_timings,output_format,output_text',
[
(False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle, un podcast'),
(False, OutputFormat.SRT, '1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
(False, OutputFormat.VTT, 'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
(True, OutputFormat.SRT, '1\n00:00:00.040 --> 00:00:00.059\n Bienvenue\n\n2\n00:00:00.059 --> 00:00:00.359\n dans P'),
])
def test_transcribe_whisper(self, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str):
output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
events = []
@ -42,15 +49,16 @@ class TestFileTranscriber:
transcriber = FileTranscriber(
model_name='tiny', use_whisper_cpp=False, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback)
output_file_path=output_file_path.as_posix(), output_format=output_format,
open_file_on_complete=False, event_callback=event_callback,
word_level_timings=word_level_timings)
transcriber.start()
transcriber.join()
assert os.path.isfile(output_file_path)
output_file = open(output_file_path, 'r', encoding='utf-8')
assert 'Bienvenue dans Passe-Relle, un podcast' in output_file.read()
assert output_text in output_file.read()
# Reports progress at 0, 0<progress<100, and 100
assert len([event for event in events if isinstance(
@ -74,7 +82,8 @@ class TestFileTranscriber:
model_name='tiny', use_whisper_cpp=False, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback)
open_file_on_complete=False, event_callback=event_callback,
word_level_timings=False)
transcriber.start()
transcriber.stop()
@ -98,7 +107,8 @@ class TestFileTranscriber:
model_name='tiny', use_whisper_cpp=True, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback)
open_file_on_complete=False, event_callback=event_callback,
word_level_timings=False)
transcriber.start()
transcriber.join()

View file

@ -38,12 +38,17 @@ class Task(enum.Enum):
TRANSCRIBE = "transcribe"
def whisper_cpp_params(language: str, task: Task, print_realtime=False, print_progress=False):
params = whisper_cpp.whisper_full_default_params(0)
def whisper_cpp_params(
language: str, task: Task, word_level_timings: bool,
print_realtime=False, print_progress=False,):
params = whisper_cpp.whisper_full_default_params(whisper_cpp.WHISPER_SAMPLING_GREEDY)
params.print_realtime = print_realtime
params.print_progress = print_progress
params.language = whisper_cpp.String(language.encode('utf-8'))
params.translate = task == Task.TRANSLATE
params.max_len = ctypes.c_int(1)
params.max_len = 1 if word_level_timings else 0
params.token_timestamps = word_level_timings
return params
@ -85,7 +90,6 @@ class WhisperCpp:
whisper_cpp.whisper_free((self.ctx))
# TODO: should this instead subclass Process?
class ModelLoader:
process: multiprocessing.Process
model_path_queue: multiprocessing.Queue
@ -133,16 +137,12 @@ class ModelLoader:
return WhisperCpp(model_path) if self.use_whisper_cpp else whisper.load_model(model_path)
def load_whisper_cpp_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
path = download_whisper_cpp_model(name)
path = download_model(name, use_whisper_cpp=True)
queue.put(path)
def load_whisper_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
with pipe_stderr(stderr_conn):
download_root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
path = download_whisper_model(whisper._MODELS[name], download_root)
path = download_model(name, use_whisper_cpp=False)
queue.put(path)
def stop(self):
@ -153,43 +153,6 @@ class ModelLoader:
return self.process.is_alive()
def download_whisper_model(url: str, root: str):
"""See whisper._download"""
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(
f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
with requests.get(url, stream=True, timeout=15) as source, open(download_target, 'wb') as output:
source.raise_for_status()
total_size = int(source.headers.get('Content-Length', 0))
with tqdm(total=total_size, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
for chunk in source.iter_content(chunk_size=8192):
output.write(chunk)
loop.update(len(chunk))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
return download_target
MODELS_SHA256 = {
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
@ -197,33 +160,39 @@ MODELS_SHA256 = {
}
def download_whisper_cpp_model(name: str):
"""Downloads a Whisper.cpp GGML model to the user cache directory."""
def download_model(name: str, use_whisper_cpp=False):
if use_whisper_cpp:
root = user_cache_dir('Buzz')
url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{name}.bin'
else:
root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
url = whisper._MODELS[name]
base_dir = user_cache_dir('Buzz')
os.makedirs(base_dir, exist_ok=True)
os.makedirs(root, exist_ok=True)
model_path = os.path.join(
base_dir, f'ggml-model-whisper-{name}.bin')
model_path = os.path.join(root, os.path.basename(url))
if os.path.exists(model_path) and not os.path.isfile(model_path):
raise RuntimeError(
f"{model_path} exists and is not a regular file")
expected_sha256 = MODELS_SHA256[name]
expected_sha256 = MODELS_SHA256[name] if use_whisper_cpp else url.split(
"/")[-2]
if os.path.isfile(model_path):
model_bytes = open(model_path, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_path
else:
warnings.warn(
f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
logging.debug(
'%s exists, but the SHA256 checksum does not match; re-downloading the file', model_path)
url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{name}.bin'
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
source.raise_for_status()
total_size = int(source.headers.get('Content-Length', 0))
with tqdm(total=total_size, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
for chunk in source.iter_content(chunk_size=8192):