diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..f11f6af --- /dev/null +++ b/.coveragerc @@ -0,0 +1,11 @@ +[run] +omit = + whisper_cpp.py + *_test.py + stable_ts/* + +[html] +directory = coverage/html + +[report] +fail_under = 75 diff --git a/.gitignore b/.gitignore index bbf4160..b27d844 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ build/ .pytest_cache/ .coverage* +!.coveragerc .env htmlcov/ libwhisper.* diff --git a/.gitmodules b/.gitmodules index b67c5d9..9d6f4d6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,7 @@ [submodule "whisper.cpp"] path = whisper.cpp url = https://github.com/chidiwilliams/whisper.cpp +[submodule "stable_ts"] + path = stable_ts + url = https://github.com/chidiwilliams/stable-ts + branch = main diff --git a/Makefile b/Makefile index 49e46ad..a3366dc 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ clean: rm -rf dist/* || true test: whisper_cpp.py - pytest --cov --cov-fail-under=69 --cov-report html + pytest --cov dist/Buzz: whisper_cpp.py pyinstaller --noconfirm Buzz.spec diff --git a/gui.py b/gui.py index dfeafce..95c172d 100644 --- a/gui.py +++ b/gui.py @@ -12,10 +12,10 @@ from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QTimer, QUrl, pyqtSignal) from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon, QKeySequence, QPixmap, QTextCursor) -from PyQt6.QtWidgets import (QApplication, QComboBox, QDialog, QFileDialog, - QGridLayout, QLabel, QMainWindow, QMessageBox, - QPlainTextEdit, QProgressDialog, QPushButton, - QVBoxLayout, QWidget) +from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog, + QFileDialog, QGridLayout, QLabel, QMainWindow, + QMessageBox, QPlainTextEdit, QProgressDialog, + QPushButton, QVBoxLayout, QWidget) from requests import get from whisper import tokenizer @@ -270,16 +270,18 @@ class FileTranscriberObject(QObject): transcriber: FileTranscriber def __init__( - self, model_name: str, use_whisper_cpp: bool, language: Optional[str], - task: Task, file_path: str, output_file_path: str, - output_format: OutputFormat, parent: Optional['QObject'], *args) -> None: + self, model_name: str, use_whisper_cpp: bool, language: Optional[str], + task: Task, file_path: str, output_file_path: str, + output_format: OutputFormat, word_level_timings: bool, + parent: Optional['QObject'], *args) -> None: super().__init__(parent, *args) self.transcriber = FileTranscriber( model_name=model_name, use_whisper_cpp=use_whisper_cpp, on_download_model_chunk=self.on_download_model_progress, language=language, task=task, file_path=file_path, output_file_path=output_file_path, output_format=output_format, - event_callback=self.on_file_transcriber_event) + event_callback=self.on_file_transcriber_event, + word_level_timings=word_level_timings) def on_download_model_progress(self, current: int, total: int): self.download_model_progress.emit((current, total)) @@ -380,6 +382,7 @@ class FileTranscriberWidget(QWidget): selected_language: Optional[str] = None selected_task = Task.TRANSCRIBE selected_output_format = OutputFormat.TXT + enabled_word_level_timings = False model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None file_transcriber: Optional[FileTranscriberObject] = None @@ -418,6 +421,11 @@ class FileTranscriberWidget(QWidget): output_formats_combo_box.output_format_changed.connect( self.on_output_format_changed) + self.word_level_timings_checkbox = QCheckBox('Word-level timings') + self.word_level_timings_checkbox.stateChanged.connect( + self.on_word_level_timings_changed) + self.word_level_timings_checkbox.setDisabled(True) + grid = ( ((0, 5, FormLabel('Task:', parent=self)), (5, 7, self.tasks_combo_box)), ((0, 5, FormLabel('Language:', parent=self)), @@ -426,6 +434,7 @@ class FileTranscriberWidget(QWidget): (5, 7, self.quality_combo_box)), ((0, 5, FormLabel('Export As:', self)), (5, 7, output_formats_combo_box)), + ((5, 7, self.word_level_timings_checkbox),), ((9, 3, self.run_button),) ) @@ -447,6 +456,8 @@ class FileTranscriberWidget(QWidget): def on_output_format_changed(self, output_format: OutputFormat): self.selected_output_format = output_format + self.word_level_timings_checkbox.setDisabled( + output_format == OutputFormat.TXT) def on_click_run(self): default_path = FileTranscriber.get_default_output_file_path( @@ -469,6 +480,7 @@ class FileTranscriberWidget(QWidget): file_path=self.file_path, language=self.selected_language, task=self.selected_task, output_file_path=output_file, output_format=self.selected_output_format, + word_level_timings=self.enabled_word_level_timings, parent=self) self.file_transcriber.download_model_progress.connect( self.on_download_model_progress) @@ -530,6 +542,9 @@ class FileTranscriberWidget(QWidget): if self.model_download_progress_dialog is not None: self.model_download_progress_dialog = None + def on_word_level_timings_changed(self, value: int): + self.enabled_word_level_timings = value == Qt.CheckState.Checked.value + class Settings(QSettings): ENABLE_GGML_INFERENCE = 'enable_ggml_inference' @@ -829,7 +844,7 @@ class FileTranscriberMainWindow(MainWindow): def __init__(self, file_path: str, parent: Optional[QWidget], *args) -> None: super().__init__(title=get_short_file_path( - file_path), w=400, h=210, parent=parent, *args) + file_path), w=400, h=240, parent=parent, *args) self.central_widget = FileTranscriberWidget(file_path, self) self.central_widget.setContentsMargins(10, 10, 10, 10) diff --git a/poetry.lock b/poetry.lock index d6c268c..240dae5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1176,6 +1176,9 @@ PyQt6-sip = [ {file = "PyQt6_sip-13.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3ac7e0800180202dcc0c7035ff88c2a6f4a0f5acb20c4a19f71d807d0f7857b7"}, {file = "PyQt6_sip-13.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bb4f2e2fdcf3a8dafe4256750bbedd9e7107c4fd8afa9c25be28423c36bb12b8"}, {file = "PyQt6_sip-13.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:de601187055d684b36ebe6e800a5deacaa55b69d71ad43312b76422cfeae0e12"}, + {file = "PyQt6_sip-13.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e3b17308ca729bcb6d25c01144c6b2e17d40812231c3ef9caaa72a78db2b1069"}, + {file = "PyQt6_sip-13.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d51704d50b82713fd7c928b7deb31e17be239ddac74fc2fd708e52bd21ecea3a"}, + {file = "PyQt6_sip-13.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:77af9c7e3f50414ec5af9b1534aaf2ba25115ae65aa5ed735111c8ef0884b862"}, {file = "PyQt6_sip-13.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:83b446d247a92d119d507dbc94fc1f47389d8118a5b6232a2859951157319a30"}, {file = "PyQt6_sip-13.4.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:802b0cfed19900183220c46895c2635f0dd062f2d275a25506423f911ef74db4"}, {file = "PyQt6_sip-13.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2694ae67811cefb6ea3ee0e9995755b45e4952f4dcadec8c04300fd828f91c75"}, @@ -1365,21 +1368,36 @@ sounddevice = [ tokenizers = [ {file = "tokenizers-0.13.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:a6f36b1b499233bb4443b5e57e20630c5e02fba61109632f5e00dab970440157"}, {file = "tokenizers-0.13.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:bc6983282ee74d638a4b6d149e5dadd8bc7ff1d0d6de663d69f099e0c6bddbeb"}, + {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16756e6ab264b162f99c0c0a8d3d521328f428b33374c5ee161c0ebec42bf3c0"}, + {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b10db6e4b036c78212c6763cb56411566edcf2668c910baa1939afd50095ce48"}, + {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:238e879d1a0f4fddc2ce5b2d00f219125df08f8532e5f1f2ba9ad42f02b7da59"}, {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"}, {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"}, {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"}, + {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, + {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, + {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, + {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, + {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, + {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0901a5c6538d2d2dc752c6b4bde7dab170fddce559ec75662cfad03b3187c8f6"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ba9baa76b5a3eefa78b6cc351315a216232fd727ee5e3ce0f7c6885d9fb531b"}, {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"}, {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"}, {file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"}, + {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, + {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, + {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3bc9f7d7f4c1aa84bb6b8d642a60272c8a2c987669e9bb0ac26daf0c6a9fc8"}, {file = "tokenizers-0.13.2-cp38-cp38-win32.whl", hash = "sha256:efbf189fb9cf29bd29e98c0437bdb9809f9de686a1e6c10e0b954410e9ca2142"}, {file = "tokenizers-0.13.2-cp38-cp38-win_amd64.whl", hash = "sha256:0b4cb2c60c094f31ea652f6cf9f349aae815f9243b860610c29a69ed0d7a88f8"}, {file = "tokenizers-0.13.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:b47d6212e7dd05784d7330b3b1e5a170809fa30e2b333ca5c93fba1463dec2b7"}, {file = "tokenizers-0.13.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:80a57501b61ec4f94fb7ce109e2b4a1a090352618efde87253b4ede6d458b605"}, + {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61507a9953f6e7dc3c972cbc57ba94c80c8f7f686fbc0876afe70ea2b8cc8b04"}, + {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c09f4fa620e879debdd1ec299bb81e3c961fd8f64f0e460e64df0818d29d845c"}, + {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:66c892d85385b202893ac6bc47b13390909e205280e5df89a41086cfec76fedb"}, {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e306b0941ad35087ae7083919a5c410a6b672be0343609d79a1171a364ce79"}, {file = "tokenizers-0.13.2-cp39-cp39-win32.whl", hash = "sha256:79189e7f706c74dbc6b753450757af172240916d6a12ed4526af5cc6d3ceca26"}, {file = "tokenizers-0.13.2-cp39-cp39-win_amd64.whl", hash = "sha256:486d637b413fddada845a10a45c74825d73d3725da42ccd8796ccd7a1c07a024"}, diff --git a/pyproject.toml b/pyproject.toml index 6fb787d..d7930c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.9.13,<3.11" sounddevice = "^0.4.5" -whisper = {git = "https://github.com/openai/whisper.git"} +whisper = { git = "https://github.com/openai/whisper.git" } torch = "1.12.1" numpy = "^1.23.3" transformers = "^4.22.1" diff --git a/stable_ts b/stable_ts new file mode 160000 index 0000000..5c19663 --- /dev/null +++ b/stable_ts @@ -0,0 +1 @@ +Subproject commit 5c1966392b0a67b79a454d3f902772a2d9085d77 diff --git a/transcriber.py b/transcriber.py index af0335f..5988acc 100644 --- a/transcriber.py +++ b/transcriber.py @@ -18,6 +18,7 @@ import whisper from sounddevice import PortAudioError from conn import pipe_stderr, pipe_stdout +from stable_ts.stable_whisper import group_word_timestamps, modify_model from whispr import (ModelLoader, Segment, Stopped, Task, WhisperCpp, read_progress, whisper_cpp_params) @@ -108,7 +109,7 @@ class RecordingTranscriber: audio=samples, params=whisper_cpp_params( language=self.language if self.language is not None else 'en', - task=self.task.value)) + task=self.task.value, word_level_timings=False)) next_text: str = result.get('text') @@ -236,6 +237,7 @@ class FileTranscriber: model_name: str, use_whisper_cpp: bool, language: Optional[str], task: Task, file_path: str, output_file_path: str, output_format: OutputFormat, + word_level_timings: bool, event_callback: Callable[[Event], None] = lambda *_: None, on_download_model_chunk: Callable[[ int, int], None] = lambda *_: None, @@ -246,6 +248,7 @@ class FileTranscriber: self.task = task self.open_file_on_complete = open_file_on_complete self.output_format = output_format + self.word_level_timings = word_level_timings self.model_name = model_name self.use_whisper_cpp = use_whisper_cpp @@ -290,6 +293,7 @@ class FileTranscriber: self.output_format, self.language if self.language is not None else 'en', self.task, True, True, + self.word_level_timings )) else: self.current_process = multiprocessing.Process( @@ -298,6 +302,7 @@ class FileTranscriber: send_pipe, model_path, self.file_path, self.language, self.task, self.output_file_path, self.open_file_on_complete, self.output_format, + self.word_level_timings )) self.current_process.start() @@ -352,18 +357,26 @@ class FileTranscriber: def transcribe_whisper( stderr_conn: Connection, model_path: str, file_path: str, language: Optional[str], task: Task, output_file_path: str, - open_file_on_complete: bool, output_format: OutputFormat): + open_file_on_complete: bool, output_format: OutputFormat, + word_level_timings: bool): with pipe_stderr(stderr_conn): model = whisper.load_model(model_path) - result = whisper.transcribe( - model=model, audio=file_path, language=language, task=task.value, verbose=False) + + if word_level_timings: + modify_model(model) + + result = model.transcribe( + audio=file_path, language=language, task=task.value, verbose=False) + + whisper_segments = group_word_timestamps( + result) if word_level_timings else result.get('segments') segments = map( lambda segment: Segment( start=segment.get('start')*1000, # s to ms end=segment.get('end')*1000, # s to ms text=segment.get('text')), - result.get('segments')) + whisper_segments) write_output(output_file_path, list( segments), open_file_on_complete, output_format) @@ -372,13 +385,14 @@ def transcribe_whisper( def transcribe_whisper_cpp( stderr_conn: Connection, model_path: str, audio: typing.Union[np.ndarray, str], output_file_path: str, open_file_on_complete: bool, output_format: OutputFormat, - language: str, task: Task, print_realtime: bool, print_progress: bool): + language: str, task: Task, print_realtime: bool, print_progress: bool, + word_level_timings: bool): # TODO: capturing output does not work because ctypes functions # See: https://stackoverflow.com/questions/9488560/capturing-print-output-from-shared-library-called-from-python-with-ctypes-module with pipe_stdout(stderr_conn), pipe_stderr(stderr_conn): model = WhisperCpp(model_path) params = whisper_cpp_params( - language, task, print_realtime, print_progress) + language, task, word_level_timings, print_realtime, print_progress) result = model.transcribe(audio=audio, params=params) segments: List[Segment] = result.get('segments') write_output( diff --git a/transcriber_test.py b/transcriber_test.py index e0aedb0..0c22bcb 100644 --- a/transcriber_test.py +++ b/transcriber_test.py @@ -1,4 +1,5 @@ import os +import pathlib import tempfile import pytest @@ -29,10 +30,16 @@ class TestFileTranscriber: assert srt.startswith('/a/b/c (Translated on ') assert srt.endswith('.srt') - def test_transcribe_whisper(self): - output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt') - if os.path.exists(output_file_path): - os.remove(output_file_path) + @pytest.mark.parametrize( + 'word_level_timings,output_format,output_text', + [ + (False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle, un podcast'), + (False, OutputFormat.SRT, '1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'), + (False, OutputFormat.VTT, 'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle, un podcast pensé pour évêyer la curiosité des apprenances'), + (True, OutputFormat.SRT, '1\n00:00:00.040 --> 00:00:00.059\n Bienvenue\n\n2\n00:00:00.059 --> 00:00:00.359\n dans P'), + ]) + def test_transcribe_whisper(self, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str): + output_file_path = tmp_path / f'whisper.{output_format.value.lower()}' events = [] @@ -42,15 +49,16 @@ class TestFileTranscriber: transcriber = FileTranscriber( model_name='tiny', use_whisper_cpp=False, language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3', - output_file_path=output_file_path, output_format=OutputFormat.TXT, - open_file_on_complete=False, event_callback=event_callback) + output_file_path=output_file_path.as_posix(), output_format=output_format, + open_file_on_complete=False, event_callback=event_callback, + word_level_timings=word_level_timings) transcriber.start() transcriber.join() assert os.path.isfile(output_file_path) output_file = open(output_file_path, 'r', encoding='utf-8') - assert 'Bienvenue dans Passe-Relle, un podcast' in output_file.read() + assert output_text in output_file.read() # Reports progress at 0, 0