Put FileTranscriber in QRunnable (#203)

This commit is contained in:
Chidi Williams 2022-12-04 18:30:24 +00:00 committed by GitHub
parent fe2292c833
commit 209c0af3b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 255 additions and 162 deletions

View file

@ -7,4 +7,4 @@ omit =
directory = coverage/html
[report]
fail_under = 78
fail_under = 77

View file

@ -16,6 +16,7 @@ jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- os: macos-latest
@ -26,7 +27,7 @@ jobs:
submodules: recursive
- uses: actions/setup-python@v4
with:
python-version: '3.9.13'
python-version: '3.10.7'
- name: Install Poetry Action
uses: snok/install-poetry@v1.3.1
@ -64,6 +65,7 @@ jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- os: macos-latest

1
.gitignore vendored
View file

@ -7,5 +7,6 @@ build/
.env
htmlcov/
libwhisper.*
whisper_cpp
whisper.dll
whisper_cpp.py

View file

@ -2,4 +2,3 @@
disable=
C0114, # missing-module-docstring
C0116, # missing-function-docstring
C0115, # missing-class-docstring

View file

@ -1,6 +1,8 @@
{
"files.associations": {
"Buzz.spec": "python"
".coveragerc": "ini",
"Buzz.spec": "python",
"iosfwd": "cpp"
},
"files.exclude": {
"**/.git": true,

View file

@ -17,6 +17,7 @@ datas += copy_metadata('tokenizers')
datas += collect_data_files('whisper')
datas += [('whisper.dll' if platform.system() ==
'Windows' else 'libwhisper.*', '.')]
datas += [('whisper_cpp', '.')]
datas += [('assets/buzz.ico', 'assets')]
datas += [('assets/buzz-icon-1024.png', 'assets')]
datas += [(shutil.which('ffmpeg'), '.')]

View file

@ -43,6 +43,7 @@ endif
clean:
rm -f $(LIBWHISPER)
rm -f whisper_cpp
rm -f buzz/whisper_cpp.py
rm -rf dist/* || true
@ -73,11 +74,15 @@ else
endif
endif
$(LIBWHISPER):
$(LIBWHISPER) whisper_cpp:
cmake -S whisper.cpp -B whisper.cpp/build/ $(CMAKE_FLAGS)
cmake --build whisper.cpp/build --verbose
cp whisper.cpp/build/$(LIBWHISPER) . || true
ls -lA whisper.cpp/build
ls -lA whisper.cpp/build/bin
cp whisper.cpp/build/bin/Debug/$(LIBWHISPER) . || true
cp whisper.cpp/build/bin/Debug/main whisper_cpp || true
cp whisper.cpp/build/$(LIBWHISPER) . || true
cp whisper.cpp/build/bin/main whisper_cpp || true
buzz/whisper_cpp.py: $(LIBWHISPER)
ctypesgen ./whisper.cpp/whisper.h -l$(LIBWHISPER) -o buzz/whisper_cpp.py

View file

@ -4,7 +4,7 @@ import os
import platform
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import humanize
import sounddevice
@ -22,7 +22,8 @@ from whisper import tokenizer
from .__version__ import VERSION
from .model_loader import ModelLoader
from .transcriber import FileTranscriber, OutputFormat, RecordingTranscriber
from .transcriber import (FileTranscriber, OutputFormat, RecordingTranscriber,
WhisperCppFileTranscriber)
from .whispr import LOADED_WHISPER_DLL, Task
APP_NAME = 'Buzz'
@ -267,27 +268,22 @@ class TranscriberProgressDialog(QProgressDialog):
class FileTranscriberObject(QObject):
download_model_progress = pyqtSignal(tuple)
event_received = pyqtSignal(object)
transcriber: FileTranscriber
def __init__(
self, model_path: str, use_whisper_cpp: bool, language: Optional[str],
self, model_path: str, language: Optional[str],
task: Task, file_path: str, output_file_path: str,
output_format: OutputFormat, word_level_timings: bool,
parent: Optional['QObject'], *args) -> None:
super().__init__(parent, *args)
self.transcriber = FileTranscriber(
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
on_download_model_chunk=self.on_download_model_progress,
model_path=model_path,
language=language, task=task, file_path=file_path,
output_file_path=output_file_path, output_format=output_format,
event_callback=self.on_file_transcriber_event,
word_level_timings=word_level_timings)
def on_download_model_progress(self, current: int, total: int):
self.download_model_progress.emit((current, total))
def on_file_transcriber_event(self, event: FileTranscriber.Event):
self.event_received.emit(event)
@ -386,7 +382,8 @@ class FileTranscriberWidget(QWidget):
enabled_word_level_timings = False
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
file_transcriber: Optional[FileTranscriberObject] = None
file_transcriber: Optional[Union[FileTranscriberObject,
WhisperCppFileTranscriber]] = None
model_loader: Optional[ModelLoader] = None
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
@ -482,17 +479,28 @@ class FileTranscriberWidget(QWidget):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog = None
self.file_transcriber = FileTranscriberObject(
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
parent=self)
self.file_transcriber.event_received.connect(
self.on_transcriber_event)
self.file_transcriber.start()
if use_whisper_cpp:
self.file_transcriber = WhisperCppFileTranscriber(
model_path=model_path, file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
)
self.file_transcriber.signals.progress.connect(
self.on_transcriber_progress)
self.file_transcriber.signals.completed.connect(
self.on_transcriber_complete)
self.pool.start(self.file_transcriber)
else:
self.file_transcriber = FileTranscriberObject(
model_path=model_path, file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
parent=self)
self.file_transcriber.event_received.connect(
self.on_transcriber_event)
self.file_transcriber.start()
self.model_loader = ModelLoader(
name=model_name, use_whisper_cpp=use_whisper_cpp)
@ -522,22 +530,28 @@ class FileTranscriberWidget(QWidget):
def on_transcriber_event(self, event: FileTranscriber.Event):
if isinstance(event, FileTranscriber.ProgressEvent):
current_size = event.current_value
total_size = event.max_value
# Create a dialog
if self.transcriber_progress_dialog is None:
self.transcriber_progress_dialog = TranscriberProgressDialog(
file_path=self.file_path, total_size=total_size, parent=self)
self.transcriber_progress_dialog.canceled.connect(
self.on_cancel_transcriber_progress_dialog)
# Update the progress of the dialog unless it has
# been canceled before this progress update arrived
if self.transcriber_progress_dialog is not None:
self.transcriber_progress_dialog.update_progress(current_size)
self.on_transcriber_progress(
(event.current_value, event.max_value))
elif isinstance(event, FileTranscriber.CompletedTranscriptionEvent):
self.reset_transcription()
self.on_transcriber_complete()
def on_transcriber_progress(self, progress: Tuple[int, int]):
(current_size, total_size) = progress
# Create a dialog
if self.transcriber_progress_dialog is None:
self.transcriber_progress_dialog = TranscriberProgressDialog(
file_path=self.file_path, total_size=total_size, parent=self)
self.transcriber_progress_dialog.canceled.connect(
self.on_cancel_transcriber_progress_dialog)
# Update the progress of the dialog unless it has
# been canceled before this progress update arrived
if self.transcriber_progress_dialog is not None:
self.transcriber_progress_dialog.update_progress(current_size)
def on_transcriber_complete(self):
self.reset_transcription()
def on_cancel_transcriber_progress_dialog(self):
if self.file_transcriber is not None:
@ -547,6 +561,7 @@ class FileTranscriberWidget(QWidget):
def reset_transcription(self):
self.run_button.setDisabled(False)
if self.transcriber_progress_dialog is not None:
self.transcriber_progress_dialog.close()
self.transcriber_progress_dialog = None
def on_cancel_model_progress_dialog(self):

View file

@ -4,21 +4,25 @@ import logging
import multiprocessing
import os
import platform
import re
import subprocess
import sys
import tempfile
import threading
import typing
from dataclasses import dataclass
from multiprocessing.connection import Connection
from threading import Thread
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple
import ffmpeg
import numpy as np
import sounddevice
import stable_whisper
import whisper
from PyQt6.QtCore import QObject, QProcess, QRunnable, pyqtSignal, pyqtSlot
from sounddevice import PortAudioError
from .conn import pipe_stderr, pipe_stdout
from .conn import pipe_stderr
from .whispr import (Segment, Task, WhisperCpp, read_progress,
whisper_cpp_params)
@ -171,27 +175,27 @@ def to_timestamp(ms: float) -> str:
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
file = open(path, 'w', encoding='utf-8')
logging.debug(
'Writing transcription output, path = %s, output format = %s, number of segments = %s', path, output_format, len(segments))
if output_format == OutputFormat.TXT:
for segment in segments:
file.write(segment.text)
with open(path, 'w', encoding='utf-8') as file:
if output_format == OutputFormat.TXT:
for segment in segments:
file.write(segment.text + ' ')
elif output_format == OutputFormat.VTT:
file.write('WEBVTT\n\n')
for segment in segments:
file.write(
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
file.write(f'{segment.text}\n\n')
elif output_format == OutputFormat.VTT:
file.write('WEBVTT\n\n')
for segment in segments:
file.write(
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
file.write(f'{segment.text}\n\n')
elif output_format == OutputFormat.SRT:
for (i, segment) in enumerate(segments):
file.write(f'{i+1}\n')
file.write(
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
file.write(f'{segment.text}\n\n')
file.close()
elif output_format == OutputFormat.SRT:
for (i, segment) in enumerate(segments):
file.write(f'{i+1}\n')
file.write(
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
file.write(f'{segment.text}\n\n')
if should_open:
try:
@ -201,6 +205,116 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
subprocess.call([opener, path])
class WhisperCppFileTranscriber(QRunnable):
class Signals(QObject):
progress = pyqtSignal(tuple) # (current, total)
completed = pyqtSignal(bool)
error = pyqtSignal(str)
signals: Signals
duration_audio_ms = sys.maxsize # max int
segments: List[Segment] = []
def __init__(
self,
model_path: str, language: Optional[str], task: Task, file_path: str,
output_file_path: str, output_format: OutputFormat,
word_level_timings: bool, open_file_on_complete=True,
) -> None:
super(WhisperCppFileTranscriber, self).__init__()
self.file_path = file_path
self.output_file_path = output_file_path
self.language = language
self.task = task
self.open_file_on_complete = open_file_on_complete
self.output_format = output_format
self.word_level_timings = word_level_timings
self.model_path = model_path
self.signals = self.Signals()
self.process = QProcess()
self.process.readyReadStandardError.connect(self.read_std_err)
self.process.readyReadStandardOutput.connect(self.read_std_out)
self.process.finished.connect(self.on_process_finished)
@pyqtSlot()
def run(self):
logging.debug(
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s',
self.file_path, self.language, self.task, self.output_file_path, self.output_format, self.model_path)
wav_file = tempfile.mktemp()+'.wav'
(
ffmpeg.input(self.file_path)
.output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
args = [
'--language', self.language if self.language is not None else 'en',
'--max-len', '1' if self.word_level_timings else '0',
'--model', self.model_path,
'--verbose'
]
if self.task == Task.TRANSLATE:
args.append('--translate')
args.append(wav_file)
logging.debug('Running whisper_cpp process, args = %s', args)
self.process.start('./whisper_cpp', args)
def on_process_finished(self):
status = self.process.exitStatus()
logging.debug('whisper_cpp process completed with status = %s', status)
if status == QProcess.ExitStatus.NormalExit:
self.signals.progress.emit(
(self.duration_audio_ms, self.duration_audio_ms))
write_output(
self.output_file_path, self.segments, self.open_file_on_complete, self.output_format)
self.signals.completed.emit(True)
def stop(self):
process_state = self.process.state()
if process_state == QProcess.ProcessState.Starting or process_state == QProcess.ProcessState.Running:
self.process.terminate()
def read_std_out(self):
output = self.process.readAllStandardOutput().data().decode('UTF-8').strip()
logging.debug('whisper_cpp (stdout): %s', output)
if len(output) > 0:
lines = output.split('\n')
for line in lines:
timings, text = line.split(' ')
start, end = self.parse_timings(timings)
segment = Segment(start, end, text.strip())
self.segments.append(segment)
self.signals.progress.emit((end, self.duration_audio_ms))
def parse_timings(self, timings: str) -> Tuple[int, int]:
start, end = timings[1:len(timings)-1].split(' --> ')
return self.parse_timestamp(start), self.parse_timestamp(end)
def parse_timestamp(self, timestamp: str) -> int:
hrs, mins, secs_ms = timestamp.split(':')
secs, ms = secs_ms.split('.')
return int(hrs)*60*60*1000 + int(mins)*60*1000 + int(secs)*1000 + int(ms)
def read_std_err(self):
output = self.process.readAllStandardError().data().decode('UTF-8').strip()
logging.debug('whisper_cpp (stderr): %s', output)
lines = output.split('\n')
for line in lines:
if line.startswith('main: processing'):
match = re.search(r'samples, (.*) sec', line)
if match is not None:
self.duration_audio_ms = round(float(match.group(1))*1000)
class FileTranscriber:
"""FileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file using the default program for opening txt files."""
@ -222,13 +336,10 @@ class FileTranscriber:
def __init__(
self,
model_path: str, use_whisper_cpp: bool,
language: Optional[str], task: Task, file_path: str,
model_path: str, language: Optional[str], task: Task, file_path: str,
output_file_path: str, output_format: OutputFormat,
word_level_timings: bool,
event_callback: Callable[[Event], None] = lambda *_: None,
on_download_model_chunk: Callable[[
int, int], None] = lambda *_: None,
open_file_on_complete=True) -> None:
self.file_path = file_path
self.output_file_path = output_file_path
@ -238,8 +349,6 @@ class FileTranscriber:
self.output_format = output_format
self.word_level_timings = word_level_timings
self.model_path = model_path
self.use_whisper_cpp = use_whisper_cpp
self.on_download_model_chunk = on_download_model_chunk
self.event_callback = event_callback
def start(self):
@ -258,31 +367,19 @@ class FileTranscriber:
return
self.event_callback(self.ProgressEvent(0, 100))
if self.use_whisper_cpp:
self.current_process = multiprocessing.Process(
target=transcribe_whisper_cpp,
args=(
send_pipe, self.model_path, self.file_path,
self.output_file_path, self.open_file_on_complete,
self.output_format,
self.language if self.language is not None else 'en',
self.task, True, True,
self.word_level_timings
))
else:
self.current_process = multiprocessing.Process(
target=transcribe_whisper,
args=(
send_pipe, self.model_path, self.file_path,
self.language, self.task, self.output_file_path,
self.open_file_on_complete, self.output_format,
self.word_level_timings
))
self.current_process = multiprocessing.Process(
target=transcribe_whisper,
args=(
send_pipe, self.model_path, self.file_path,
self.language, self.task, self.output_file_path,
self.open_file_on_complete, self.output_format,
self.word_level_timings
))
self.current_process.start()
thread = Thread(target=read_progress, args=(
recv_pipe, self.use_whisper_cpp,
recv_pipe,
lambda current_value, max_value: self.event_callback(self.ProgressEvent(current_value, max_value))))
thread.start()
@ -351,20 +448,3 @@ def transcribe_whisper(
write_output(output_file_path, list(
segments), open_file_on_complete, output_format)
def transcribe_whisper_cpp(
stderr_conn: Connection, model_path: str, audio: typing.Union[np.ndarray, str],
output_file_path: str, open_file_on_complete: bool, output_format: OutputFormat,
language: str, task: Task, print_realtime: bool, print_progress: bool,
word_level_timings: bool):
# TODO: capturing output does not work because ctypes functions
# See: https://stackoverflow.com/questions/9488560/capturing-print-output-from-shared-library-called-from-python-with-ctypes-module
with pipe_stdout(stderr_conn), pipe_stderr(stderr_conn):
model = WhisperCpp(model_path)
params = whisper_cpp_params(
language, task, word_level_timings, print_realtime, print_progress)
result = model.transcribe(audio=audio, params=params)
segments: List[Segment] = result.get('segments')
write_output(
output_file_path, segments, open_file_on_complete, output_format)

View file

@ -26,8 +26,8 @@ class Stopped(Exception):
@dataclass
class Segment:
start: float
end: float
start: int # start time in ms
end: int # end time in ms
text: str
@ -95,33 +95,12 @@ def tqdm_progress(line: str):
return int(percent_progress)
def whisper_cpp_progress(lines: str):
"""Extracts the progress of a whisper.cpp transcription.
The log lines have the following format:
whisper_full: progress = 20%\n
"""
# Example log line: "whisper_full: progress = 20%"
progress_lines = list(filter(lambda line: line.startswith(
'whisper_full: progress'), lines.split('\n')))
if len(progress_lines) == 0:
raise ValueError('No lines match whisper.cpp progress format')
last_word = progress_lines[-1].split(' ')[-1]
return min(int(last_word[:-1]), 100)
def read_progress(
pipe: Connection, use_whisper_cpp: bool,
progress_callback: Callable[[int, int], None]):
def read_progress(pipe: Connection, progress_callback: Callable[[int, int], None]):
while pipe.closed is False:
try:
recv = pipe.recv().strip()
if recv:
if use_whisper_cpp:
progress = whisper_cpp_progress(recv)
else:
progress = tqdm_progress(recv)
progress = tqdm_progress(recv)
progress_callback(progress, 100)
except ValueError:
pass

View file

@ -1,13 +1,16 @@
import logging
import os
import pathlib
import tempfile
import time
from unittest.mock import Mock
import pytest
from buzz.model_loader import ModelLoader
from buzz.transcriber import (FileTranscriber, OutputFormat,
RecordingTranscriber, to_timestamp)
RecordingTranscriber, WhisperCppFileTranscriber,
to_timestamp)
from buzz.whispr import Task
@ -33,6 +36,38 @@ class TestRecordingTranscriber:
assert transcriber is not None
class TestWhisperCppFileTranscriber:
@pytest.mark.parametrize(
'task,output_text',
[
(Task.TRANSCRIBE, 'Bienvenue dans Passe-Relle, un podcast'),
(Task.TRANSLATE, 'Welcome to Passe-Relle, a podcast'),
])
def test_transcribe(self, qtbot, tmp_path: pathlib.Path, task: Task, output_text: str):
output_file_path = tmp_path / 'whisper_cpp.txt'
if os.path.exists(output_file_path):
os.remove(output_file_path)
model_path = get_model_path('tiny', True)
transcriber = WhisperCppFileTranscriber(
model_path=model_path, language='fr',
task=task, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path.as_posix(), output_format=OutputFormat.TXT,
open_file_on_complete=False,
word_level_timings=False)
mock_progress = Mock()
with qtbot.waitSignal(transcriber.signals.completed, timeout=10*60*1000):
transcriber.signals.progress.connect(mock_progress)
transcriber.run()
assert os.path.isfile(output_file_path)
output_file = open(output_file_path, 'r', encoding='utf-8')
assert output_text in output_file.read()
mock_progress.assert_called()
class TestFileTranscriber:
def test_default_output_file(self):
srt = FileTranscriber.get_default_output_file_path(
@ -64,7 +99,7 @@ class TestFileTranscriber:
model_path = get_model_path('tiny', False)
transcriber = FileTranscriber(
model_path=model_path, use_whisper_cpp=False, language='fr',
model_path=model_path, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path.as_posix(), output_format=output_format,
open_file_on_complete=False, event_callback=event_callback,
@ -97,7 +132,7 @@ class TestFileTranscriber:
model_path = get_model_path('tiny', False)
transcriber = FileTranscriber(
model_path=model_path, use_whisper_cpp=False, language='fr',
model_path=model_path, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback,
@ -109,32 +144,6 @@ class TestFileTranscriber:
# Assert that file was not created
assert os.path.isfile(output_file_path) is False
def test_transcribe_whisper_cpp(self):
output_file_path = os.path.join(
tempfile.gettempdir(), 'whisper_cpp.txt')
if os.path.exists(output_file_path):
os.remove(output_file_path)
events = []
def event_callback(event: FileTranscriber.Event):
events.append(event)
model_path = get_model_path('tiny', True)
transcriber = FileTranscriber(
model_path=model_path, use_whisper_cpp=True, language='fr',
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False, event_callback=event_callback,
word_level_timings=False)
transcriber.start()
transcriber.join()
assert os.path.isfile(output_file_path)
output_file = open(output_file_path, 'r', encoding='utf-8')
assert 'Bienvenue dans Passe-Relle, un podcast' in output_file.read()
class TestToTimestamp:
def test_to_timestamp(self):

@ -1 +1 @@
Subproject commit 745f999d2dc32f2caeceb8e45d555ccd41e07669
Subproject commit 9ab012f37aa18ce9504ce2343e3c99ac37778498