mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-26 11:40:09 +02:00
Add stable timings (#159)
This commit is contained in:
parent
a1b9097133
commit
392f9cb469
11
.coveragerc
Normal file
11
.coveragerc
Normal file
|
@ -0,0 +1,11 @@
|
|||
[run]
|
||||
omit =
|
||||
whisper_cpp.py
|
||||
*_test.py
|
||||
stable_ts/*
|
||||
|
||||
[html]
|
||||
directory = coverage/html
|
||||
|
||||
[report]
|
||||
fail_under = 75
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -3,6 +3,7 @@ __pycache__/
|
|||
build/
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
!.coveragerc
|
||||
.env
|
||||
htmlcov/
|
||||
libwhisper.*
|
||||
|
|
4
.gitmodules
vendored
4
.gitmodules
vendored
|
@ -1,3 +1,7 @@
|
|||
[submodule "whisper.cpp"]
|
||||
path = whisper.cpp
|
||||
url = https://github.com/chidiwilliams/whisper.cpp
|
||||
[submodule "stable_ts"]
|
||||
path = stable_ts
|
||||
url = https://github.com/chidiwilliams/stable-ts
|
||||
branch = main
|
||||
|
|
2
Makefile
2
Makefile
|
@ -43,7 +43,7 @@ clean:
|
|||
rm -rf dist/* || true
|
||||
|
||||
test: whisper_cpp.py
|
||||
pytest --cov --cov-fail-under=69 --cov-report html
|
||||
pytest --cov
|
||||
|
||||
dist/Buzz: whisper_cpp.py
|
||||
pyinstaller --noconfirm Buzz.spec
|
||||
|
|
33
gui.py
33
gui.py
|
@ -12,10 +12,10 @@ from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QTimer,
|
|||
QUrl, pyqtSignal)
|
||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||
QKeySequence, QPixmap, QTextCursor)
|
||||
from PyQt6.QtWidgets import (QApplication, QComboBox, QDialog, QFileDialog,
|
||||
QGridLayout, QLabel, QMainWindow, QMessageBox,
|
||||
QPlainTextEdit, QProgressDialog, QPushButton,
|
||||
QVBoxLayout, QWidget)
|
||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||
QFileDialog, QGridLayout, QLabel, QMainWindow,
|
||||
QMessageBox, QPlainTextEdit, QProgressDialog,
|
||||
QPushButton, QVBoxLayout, QWidget)
|
||||
from requests import get
|
||||
from whisper import tokenizer
|
||||
|
||||
|
@ -270,16 +270,18 @@ class FileTranscriberObject(QObject):
|
|||
transcriber: FileTranscriber
|
||||
|
||||
def __init__(
|
||||
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
|
||||
task: Task, file_path: str, output_file_path: str,
|
||||
output_format: OutputFormat, parent: Optional['QObject'], *args) -> None:
|
||||
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
|
||||
task: Task, file_path: str, output_file_path: str,
|
||||
output_format: OutputFormat, word_level_timings: bool,
|
||||
parent: Optional['QObject'], *args) -> None:
|
||||
super().__init__(parent, *args)
|
||||
self.transcriber = FileTranscriber(
|
||||
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
|
||||
on_download_model_chunk=self.on_download_model_progress,
|
||||
language=language, task=task, file_path=file_path,
|
||||
output_file_path=output_file_path, output_format=output_format,
|
||||
event_callback=self.on_file_transcriber_event)
|
||||
event_callback=self.on_file_transcriber_event,
|
||||
word_level_timings=word_level_timings)
|
||||
|
||||
def on_download_model_progress(self, current: int, total: int):
|
||||
self.download_model_progress.emit((current, total))
|
||||
|
@ -380,6 +382,7 @@ class FileTranscriberWidget(QWidget):
|
|||
selected_language: Optional[str] = None
|
||||
selected_task = Task.TRANSCRIBE
|
||||
selected_output_format = OutputFormat.TXT
|
||||
enabled_word_level_timings = False
|
||||
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
|
||||
file_transcriber: Optional[FileTranscriberObject] = None
|
||||
|
@ -418,6 +421,11 @@ class FileTranscriberWidget(QWidget):
|
|||
output_formats_combo_box.output_format_changed.connect(
|
||||
self.on_output_format_changed)
|
||||
|
||||
self.word_level_timings_checkbox = QCheckBox('Word-level timings')
|
||||
self.word_level_timings_checkbox.stateChanged.connect(
|
||||
self.on_word_level_timings_changed)
|
||||
self.word_level_timings_checkbox.setDisabled(True)
|
||||
|
||||
grid = (
|
||||
((0, 5, FormLabel('Task:', parent=self)), (5, 7, self.tasks_combo_box)),
|
||||
((0, 5, FormLabel('Language:', parent=self)),
|
||||
|
@ -426,6 +434,7 @@ class FileTranscriberWidget(QWidget):
|
|||
(5, 7, self.quality_combo_box)),
|
||||
((0, 5, FormLabel('Export As:', self)),
|
||||
(5, 7, output_formats_combo_box)),
|
||||
((5, 7, self.word_level_timings_checkbox),),
|
||||
((9, 3, self.run_button),)
|
||||
)
|
||||
|
||||
|
@ -447,6 +456,8 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
def on_output_format_changed(self, output_format: OutputFormat):
|
||||
self.selected_output_format = output_format
|
||||
self.word_level_timings_checkbox.setDisabled(
|
||||
output_format == OutputFormat.TXT)
|
||||
|
||||
def on_click_run(self):
|
||||
default_path = FileTranscriber.get_default_output_file_path(
|
||||
|
@ -469,6 +480,7 @@ class FileTranscriberWidget(QWidget):
|
|||
file_path=self.file_path,
|
||||
language=self.selected_language, task=self.selected_task,
|
||||
output_file_path=output_file, output_format=self.selected_output_format,
|
||||
word_level_timings=self.enabled_word_level_timings,
|
||||
parent=self)
|
||||
self.file_transcriber.download_model_progress.connect(
|
||||
self.on_download_model_progress)
|
||||
|
@ -530,6 +542,9 @@ class FileTranscriberWidget(QWidget):
|
|||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
def on_word_level_timings_changed(self, value: int):
|
||||
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value
|
||||
|
||||
|
||||
class Settings(QSettings):
|
||||
ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
|
||||
|
@ -829,7 +844,7 @@ class FileTranscriberMainWindow(MainWindow):
|
|||
|
||||
def __init__(self, file_path: str, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(title=get_short_file_path(
|
||||
file_path), w=400, h=210, parent=parent, *args)
|
||||
file_path), w=400, h=240, parent=parent, *args)
|
||||
|
||||
self.central_widget = FileTranscriberWidget(file_path, self)
|
||||
self.central_widget.setContentsMargins(10, 10, 10, 10)
|
||||
|
|
18
poetry.lock
generated
18
poetry.lock
generated
|
@ -1176,6 +1176,9 @@ PyQt6-sip = [
|
|||
{file = "PyQt6_sip-13.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3ac7e0800180202dcc0c7035ff88c2a6f4a0f5acb20c4a19f71d807d0f7857b7"},
|
||||
{file = "PyQt6_sip-13.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bb4f2e2fdcf3a8dafe4256750bbedd9e7107c4fd8afa9c25be28423c36bb12b8"},
|
||||
{file = "PyQt6_sip-13.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:de601187055d684b36ebe6e800a5deacaa55b69d71ad43312b76422cfeae0e12"},
|
||||
{file = "PyQt6_sip-13.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e3b17308ca729bcb6d25c01144c6b2e17d40812231c3ef9caaa72a78db2b1069"},
|
||||
{file = "PyQt6_sip-13.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d51704d50b82713fd7c928b7deb31e17be239ddac74fc2fd708e52bd21ecea3a"},
|
||||
{file = "PyQt6_sip-13.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:77af9c7e3f50414ec5af9b1534aaf2ba25115ae65aa5ed735111c8ef0884b862"},
|
||||
{file = "PyQt6_sip-13.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:83b446d247a92d119d507dbc94fc1f47389d8118a5b6232a2859951157319a30"},
|
||||
{file = "PyQt6_sip-13.4.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:802b0cfed19900183220c46895c2635f0dd062f2d275a25506423f911ef74db4"},
|
||||
{file = "PyQt6_sip-13.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2694ae67811cefb6ea3ee0e9995755b45e4952f4dcadec8c04300fd828f91c75"},
|
||||
|
@ -1365,21 +1368,36 @@ sounddevice = [
|
|||
tokenizers = [
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:a6f36b1b499233bb4443b5e57e20630c5e02fba61109632f5e00dab970440157"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:bc6983282ee74d638a4b6d149e5dadd8bc7ff1d0d6de663d69f099e0c6bddbeb"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16756e6ab264b162f99c0c0a8d3d521328f428b33374c5ee161c0ebec42bf3c0"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b10db6e4b036c78212c6763cb56411566edcf2668c910baa1939afd50095ce48"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:238e879d1a0f4fddc2ce5b2d00f219125df08f8532e5f1f2ba9ad42f02b7da59"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0901a5c6538d2d2dc752c6b4bde7dab170fddce559ec75662cfad03b3187c8f6"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ba9baa76b5a3eefa78b6cc351315a216232fd727ee5e3ce0f7c6885d9fb531b"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3bc9f7d7f4c1aa84bb6b8d642a60272c8a2c987669e9bb0ac26daf0c6a9fc8"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-win32.whl", hash = "sha256:efbf189fb9cf29bd29e98c0437bdb9809f9de686a1e6c10e0b954410e9ca2142"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-win_amd64.whl", hash = "sha256:0b4cb2c60c094f31ea652f6cf9f349aae815f9243b860610c29a69ed0d7a88f8"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:b47d6212e7dd05784d7330b3b1e5a170809fa30e2b333ca5c93fba1463dec2b7"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:80a57501b61ec4f94fb7ce109e2b4a1a090352618efde87253b4ede6d458b605"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61507a9953f6e7dc3c972cbc57ba94c80c8f7f686fbc0876afe70ea2b8cc8b04"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c09f4fa620e879debdd1ec299bb81e3c961fd8f64f0e460e64df0818d29d845c"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:66c892d85385b202893ac6bc47b13390909e205280e5df89a41086cfec76fedb"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e306b0941ad35087ae7083919a5c410a6b672be0343609d79a1171a364ce79"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-win32.whl", hash = "sha256:79189e7f706c74dbc6b753450757af172240916d6a12ed4526af5cc6d3ceca26"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-win_amd64.whl", hash = "sha256:486d637b413fddada845a10a45c74825d73d3725da42ccd8796ccd7a1c07a024"},
|
||||
|
|
|
@ -9,7 +9,7 @@ readme = "README.md"
|
|||
[tool.poetry.dependencies]
|
||||
python = ">=3.9.13,<3.11"
|
||||
sounddevice = "^0.4.5"
|
||||
whisper = {git = "https://github.com/openai/whisper.git"}
|
||||
whisper = { git = "https://github.com/openai/whisper.git" }
|
||||
torch = "1.12.1"
|
||||
numpy = "^1.23.3"
|
||||
transformers = "^4.22.1"
|
||||
|
|
1
stable_ts
Submodule
1
stable_ts
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 5c1966392b0a67b79a454d3f902772a2d9085d77
|
|
@ -18,6 +18,7 @@ import whisper
|
|||
from sounddevice import PortAudioError
|
||||
|
||||
from conn import pipe_stderr, pipe_stdout
|
||||
from stable_ts.stable_whisper import group_word_timestamps, modify_model
|
||||
from whispr import (ModelLoader, Segment, Stopped, Task, WhisperCpp,
|
||||
read_progress, whisper_cpp_params)
|
||||
|
||||
|
@ -108,7 +109,7 @@ class RecordingTranscriber:
|
|||
audio=samples,
|
||||
params=whisper_cpp_params(
|
||||
language=self.language if self.language is not None else 'en',
|
||||
task=self.task.value))
|
||||
task=self.task.value, word_level_timings=False))
|
||||
|
||||
next_text: str = result.get('text')
|
||||
|
||||
|
@ -236,6 +237,7 @@ class FileTranscriber:
|
|||
model_name: str, use_whisper_cpp: bool,
|
||||
language: Optional[str], task: Task, file_path: str,
|
||||
output_file_path: str, output_format: OutputFormat,
|
||||
word_level_timings: bool,
|
||||
event_callback: Callable[[Event], None] = lambda *_: None,
|
||||
on_download_model_chunk: Callable[[
|
||||
int, int], None] = lambda *_: None,
|
||||
|
@ -246,6 +248,7 @@ class FileTranscriber:
|
|||
self.task = task
|
||||
self.open_file_on_complete = open_file_on_complete
|
||||
self.output_format = output_format
|
||||
self.word_level_timings = word_level_timings
|
||||
|
||||
self.model_name = model_name
|
||||
self.use_whisper_cpp = use_whisper_cpp
|
||||
|
@ -290,6 +293,7 @@ class FileTranscriber:
|
|||
self.output_format,
|
||||
self.language if self.language is not None else 'en',
|
||||
self.task, True, True,
|
||||
self.word_level_timings
|
||||
))
|
||||
else:
|
||||
self.current_process = multiprocessing.Process(
|
||||
|
@ -298,6 +302,7 @@ class FileTranscriber:
|
|||
send_pipe, model_path, self.file_path,
|
||||
self.language, self.task, self.output_file_path,
|
||||
self.open_file_on_complete, self.output_format,
|
||||
self.word_level_timings
|
||||
))
|
||||
|
||||
self.current_process.start()
|
||||
|
@ -352,18 +357,26 @@ class FileTranscriber:
|
|||
def transcribe_whisper(
|
||||
stderr_conn: Connection, model_path: str, file_path: str,
|
||||
language: Optional[str], task: Task, output_file_path: str,
|
||||
open_file_on_complete: bool, output_format: OutputFormat):
|
||||
open_file_on_complete: bool, output_format: OutputFormat,
|
||||
word_level_timings: bool):
|
||||
with pipe_stderr(stderr_conn):
|
||||
model = whisper.load_model(model_path)
|
||||
result = whisper.transcribe(
|
||||
model=model, audio=file_path, language=language, task=task.value, verbose=False)
|
||||
|
||||
if word_level_timings:
|
||||
modify_model(model)
|
||||
|
||||
result = model.transcribe(
|
||||
audio=file_path, language=language, task=task.value, verbose=False)
|
||||
|
||||
whisper_segments = group_word_timestamps(
|
||||
result) if word_level_timings else result.get('segments')
|
||||
|
||||
segments = map(
|
||||
lambda segment: Segment(
|
||||
start=segment.get('start')*1000, # s to ms
|
||||
end=segment.get('end')*1000, # s to ms
|
||||
text=segment.get('text')),
|
||||
result.get('segments'))
|
||||
whisper_segments)
|
||||
|
||||
write_output(output_file_path, list(
|
||||
segments), open_file_on_complete, output_format)
|
||||
|
@ -372,13 +385,14 @@ def transcribe_whisper(
|
|||
def transcribe_whisper_cpp(
|
||||
stderr_conn: Connection, model_path: str, audio: typing.Union[np.ndarray, str],
|
||||
output_file_path: str, open_file_on_complete: bool, output_format: OutputFormat,
|
||||
language: str, task: Task, print_realtime: bool, print_progress: bool):
|
||||
language: str, task: Task, print_realtime: bool, print_progress: bool,
|
||||
word_level_timings: bool):
|
||||
# TODO: capturing output does not work because ctypes functions
|
||||
# See: https://stackoverflow.com/questions/9488560/capturing-print-output-from-shared-library-called-from-python-with-ctypes-module
|
||||
with pipe_stdout(stderr_conn), pipe_stderr(stderr_conn):
|
||||
model = WhisperCpp(model_path)
|
||||
params = whisper_cpp_params(
|
||||
language, task, print_realtime, print_progress)
|
||||
language, task, word_level_timings, print_realtime, print_progress)
|
||||
result = model.transcribe(audio=audio, params=params)
|
||||
segments: List[Segment] = result.get('segments')
|
||||
write_output(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
@ -29,10 +30,16 @@ class TestFileTranscriber:
|
|||
assert srt.startswith('/a/b/c (Translated on ')
|
||||
assert srt.endswith('.srt')
|
||||
|
||||
def test_transcribe_whisper(self):
|
||||
output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt')
|
||||
if os.path.exists(output_file_path):
|
||||
os.remove(output_file_path)
|
||||
@pytest.mark.parametrize(
|
||||
'word_level_timings,output_format,output_text',
|
||||
[
|
||||
(False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle, un podcast'),
|
||||
(False, OutputFormat.SRT, '1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
|
||||
(False, OutputFormat.VTT, 'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'),
|
||||
(True, OutputFormat.SRT, '1\n00:00:00.040 --> 00:00:00.059\n Bienvenue\n\n2\n00:00:00.059 --> 00:00:00.359\n dans P'),
|
||||
])
|
||||
def test_transcribe_whisper(self, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str):
|
||||
output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
|
||||
|
||||
events = []
|
||||
|
||||
|
@ -42,15 +49,16 @@ class TestFileTranscriber:
|
|||
transcriber = FileTranscriber(
|
||||
model_name='tiny', use_whisper_cpp=False, language='fr',
|
||||
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False, event_callback=event_callback)
|
||||
output_file_path=output_file_path.as_posix(), output_format=output_format,
|
||||
open_file_on_complete=False, event_callback=event_callback,
|
||||
word_level_timings=word_level_timings)
|
||||
transcriber.start()
|
||||
transcriber.join()
|
||||
|
||||
assert os.path.isfile(output_file_path)
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert 'Bienvenue dans Passe-Relle, un podcast' in output_file.read()
|
||||
assert output_text in output_file.read()
|
||||
|
||||
# Reports progress at 0, 0<progress<100, and 100
|
||||
assert len([event for event in events if isinstance(
|
||||
|
@ -74,7 +82,8 @@ class TestFileTranscriber:
|
|||
model_name='tiny', use_whisper_cpp=False, language='fr',
|
||||
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False, event_callback=event_callback)
|
||||
open_file_on_complete=False, event_callback=event_callback,
|
||||
word_level_timings=False)
|
||||
transcriber.start()
|
||||
transcriber.stop()
|
||||
|
||||
|
@ -98,7 +107,8 @@ class TestFileTranscriber:
|
|||
model_name='tiny', use_whisper_cpp=True, language='fr',
|
||||
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False, event_callback=event_callback)
|
||||
open_file_on_complete=False, event_callback=event_callback,
|
||||
word_level_timings=False)
|
||||
transcriber.start()
|
||||
transcriber.join()
|
||||
|
||||
|
|
87
whispr.py
87
whispr.py
|
@ -38,12 +38,17 @@ class Task(enum.Enum):
|
|||
TRANSCRIBE = "transcribe"
|
||||
|
||||
|
||||
def whisper_cpp_params(language: str, task: Task, print_realtime=False, print_progress=False):
|
||||
params = whisper_cpp.whisper_full_default_params(0)
|
||||
def whisper_cpp_params(
|
||||
language: str, task: Task, word_level_timings: bool,
|
||||
print_realtime=False, print_progress=False,):
|
||||
params = whisper_cpp.whisper_full_default_params(whisper_cpp.WHISPER_SAMPLING_GREEDY)
|
||||
params.print_realtime = print_realtime
|
||||
params.print_progress = print_progress
|
||||
params.language = whisper_cpp.String(language.encode('utf-8'))
|
||||
params.translate = task == Task.TRANSLATE
|
||||
params.max_len = ctypes.c_int(1)
|
||||
params.max_len = 1 if word_level_timings else 0
|
||||
params.token_timestamps = word_level_timings
|
||||
return params
|
||||
|
||||
|
||||
|
@ -85,7 +90,6 @@ class WhisperCpp:
|
|||
whisper_cpp.whisper_free((self.ctx))
|
||||
|
||||
|
||||
# TODO: should this instead subclass Process?
|
||||
class ModelLoader:
|
||||
process: multiprocessing.Process
|
||||
model_path_queue: multiprocessing.Queue
|
||||
|
@ -133,16 +137,12 @@ class ModelLoader:
|
|||
return WhisperCpp(model_path) if self.use_whisper_cpp else whisper.load_model(model_path)
|
||||
|
||||
def load_whisper_cpp_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
|
||||
path = download_whisper_cpp_model(name)
|
||||
path = download_model(name, use_whisper_cpp=True)
|
||||
queue.put(path)
|
||||
|
||||
def load_whisper_model(self, stderr_conn: Connection, queue: multiprocessing.Queue, name: str):
|
||||
with pipe_stderr(stderr_conn):
|
||||
download_root = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
path = download_whisper_model(whisper._MODELS[name], download_root)
|
||||
path = download_model(name, use_whisper_cpp=False)
|
||||
queue.put(path)
|
||||
|
||||
def stop(self):
|
||||
|
@ -153,43 +153,6 @@ class ModelLoader:
|
|||
return self.process.is_alive()
|
||||
|
||||
|
||||
def download_whisper_model(url: str, root: str):
|
||||
"""See whisper._download"""
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(
|
||||
f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
# Downloads the model using the requests module instead of urllib to
|
||||
# use the certs from certifi when the app is running in frozen mode
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(download_target, 'wb') as output:
|
||||
source.raise_for_status()
|
||||
total_size = int(source.headers.get('Content-Length', 0))
|
||||
with tqdm(total=total_size, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
for chunk in source.iter_content(chunk_size=8192):
|
||||
output.write(chunk)
|
||||
loop.update(len(chunk))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
MODELS_SHA256 = {
|
||||
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
|
||||
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
|
||||
|
@ -197,33 +160,39 @@ MODELS_SHA256 = {
|
|||
}
|
||||
|
||||
|
||||
def download_whisper_cpp_model(name: str):
|
||||
"""Downloads a Whisper.cpp GGML model to the user cache directory."""
|
||||
def download_model(name: str, use_whisper_cpp=False):
|
||||
if use_whisper_cpp:
|
||||
root = user_cache_dir('Buzz')
|
||||
url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{name}.bin'
|
||||
else:
|
||||
root = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
url = whisper._MODELS[name]
|
||||
|
||||
base_dir = user_cache_dir('Buzz')
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
model_path = os.path.join(
|
||||
base_dir, f'ggml-model-whisper-{name}.bin')
|
||||
model_path = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(model_path) and not os.path.isfile(model_path):
|
||||
raise RuntimeError(
|
||||
f"{model_path} exists and is not a regular file")
|
||||
|
||||
expected_sha256 = MODELS_SHA256[name]
|
||||
|
||||
expected_sha256 = MODELS_SHA256[name] if use_whisper_cpp else url.split(
|
||||
"/")[-2]
|
||||
if os.path.isfile(model_path):
|
||||
model_bytes = open(model_path, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_path
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
logging.debug(
|
||||
'%s exists, but the SHA256 checksum does not match; re-downloading the file', model_path)
|
||||
|
||||
url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{name}.bin'
|
||||
# Downloads the model using the requests module instead of urllib to
|
||||
# use the certs from certifi when the app is running in frozen mode
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
|
||||
source.raise_for_status()
|
||||
|
||||
total_size = int(source.headers.get('Content-Length', 0))
|
||||
with tqdm(total=total_size, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
for chunk in source.iter_content(chunk_size=8192):
|
||||
|
|
Loading…
Reference in a new issue