mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-08 03:02:12 +02:00
Put FileTranscriber in QRunnable (#203)
This commit is contained in:
parent
fe2292c833
commit
209c0af3b8
|
@ -7,4 +7,4 @@ omit =
|
|||
directory = coverage/html
|
||||
|
||||
[report]
|
||||
fail_under = 78
|
||||
fail_under = 77
|
||||
|
|
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
|
@ -16,6 +16,7 @@ jobs:
|
|||
test:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-latest
|
||||
|
@ -26,7 +27,7 @@ jobs:
|
|||
submodules: recursive
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.9.13'
|
||||
python-version: '3.10.7'
|
||||
|
||||
- name: Install Poetry Action
|
||||
uses: snok/install-poetry@v1.3.1
|
||||
|
@ -64,6 +65,7 @@ jobs:
|
|||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-latest
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -7,5 +7,6 @@ build/
|
|||
.env
|
||||
htmlcov/
|
||||
libwhisper.*
|
||||
whisper_cpp
|
||||
whisper.dll
|
||||
whisper_cpp.py
|
||||
|
|
|
@ -2,4 +2,3 @@
|
|||
disable=
|
||||
C0114, # missing-module-docstring
|
||||
C0116, # missing-function-docstring
|
||||
C0115, # missing-class-docstring
|
||||
|
|
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
|
@ -1,6 +1,8 @@
|
|||
{
|
||||
"files.associations": {
|
||||
"Buzz.spec": "python"
|
||||
".coveragerc": "ini",
|
||||
"Buzz.spec": "python",
|
||||
"iosfwd": "cpp"
|
||||
},
|
||||
"files.exclude": {
|
||||
"**/.git": true,
|
||||
|
|
|
@ -17,6 +17,7 @@ datas += copy_metadata('tokenizers')
|
|||
datas += collect_data_files('whisper')
|
||||
datas += [('whisper.dll' if platform.system() ==
|
||||
'Windows' else 'libwhisper.*', '.')]
|
||||
datas += [('whisper_cpp', '.')]
|
||||
datas += [('assets/buzz.ico', 'assets')]
|
||||
datas += [('assets/buzz-icon-1024.png', 'assets')]
|
||||
datas += [(shutil.which('ffmpeg'), '.')]
|
||||
|
|
9
Makefile
9
Makefile
|
@ -43,6 +43,7 @@ endif
|
|||
|
||||
clean:
|
||||
rm -f $(LIBWHISPER)
|
||||
rm -f whisper_cpp
|
||||
rm -f buzz/whisper_cpp.py
|
||||
rm -rf dist/* || true
|
||||
|
||||
|
@ -73,11 +74,15 @@ else
|
|||
endif
|
||||
endif
|
||||
|
||||
$(LIBWHISPER):
|
||||
$(LIBWHISPER) whisper_cpp:
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ $(CMAKE_FLAGS)
|
||||
cmake --build whisper.cpp/build --verbose
|
||||
cp whisper.cpp/build/$(LIBWHISPER) . || true
|
||||
ls -lA whisper.cpp/build
|
||||
ls -lA whisper.cpp/build/bin
|
||||
cp whisper.cpp/build/bin/Debug/$(LIBWHISPER) . || true
|
||||
cp whisper.cpp/build/bin/Debug/main whisper_cpp || true
|
||||
cp whisper.cpp/build/$(LIBWHISPER) . || true
|
||||
cp whisper.cpp/build/bin/main whisper_cpp || true
|
||||
|
||||
buzz/whisper_cpp.py: $(LIBWHISPER)
|
||||
ctypesgen ./whisper.cpp/whisper.h -l$(LIBWHISPER) -o buzz/whisper_cpp.py
|
||||
|
|
87
buzz/gui.py
87
buzz/gui.py
|
@ -4,7 +4,7 @@ import os
|
|||
import platform
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import humanize
|
||||
import sounddevice
|
||||
|
@ -22,7 +22,8 @@ from whisper import tokenizer
|
|||
|
||||
from .__version__ import VERSION
|
||||
from .model_loader import ModelLoader
|
||||
from .transcriber import FileTranscriber, OutputFormat, RecordingTranscriber
|
||||
from .transcriber import (FileTranscriber, OutputFormat, RecordingTranscriber,
|
||||
WhisperCppFileTranscriber)
|
||||
from .whispr import LOADED_WHISPER_DLL, Task
|
||||
|
||||
APP_NAME = 'Buzz'
|
||||
|
@ -267,27 +268,22 @@ class TranscriberProgressDialog(QProgressDialog):
|
|||
|
||||
|
||||
class FileTranscriberObject(QObject):
|
||||
download_model_progress = pyqtSignal(tuple)
|
||||
event_received = pyqtSignal(object)
|
||||
transcriber: FileTranscriber
|
||||
|
||||
def __init__(
|
||||
self, model_path: str, use_whisper_cpp: bool, language: Optional[str],
|
||||
self, model_path: str, language: Optional[str],
|
||||
task: Task, file_path: str, output_file_path: str,
|
||||
output_format: OutputFormat, word_level_timings: bool,
|
||||
parent: Optional['QObject'], *args) -> None:
|
||||
super().__init__(parent, *args)
|
||||
self.transcriber = FileTranscriber(
|
||||
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
|
||||
on_download_model_chunk=self.on_download_model_progress,
|
||||
model_path=model_path,
|
||||
language=language, task=task, file_path=file_path,
|
||||
output_file_path=output_file_path, output_format=output_format,
|
||||
event_callback=self.on_file_transcriber_event,
|
||||
word_level_timings=word_level_timings)
|
||||
|
||||
def on_download_model_progress(self, current: int, total: int):
|
||||
self.download_model_progress.emit((current, total))
|
||||
|
||||
def on_file_transcriber_event(self, event: FileTranscriber.Event):
|
||||
self.event_received.emit(event)
|
||||
|
||||
|
@ -386,7 +382,8 @@ class FileTranscriberWidget(QWidget):
|
|||
enabled_word_level_timings = False
|
||||
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
|
||||
file_transcriber: Optional[FileTranscriberObject] = None
|
||||
file_transcriber: Optional[Union[FileTranscriberObject,
|
||||
WhisperCppFileTranscriber]] = None
|
||||
model_loader: Optional[ModelLoader] = None
|
||||
|
||||
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
|
||||
|
@ -482,17 +479,28 @@ class FileTranscriberWidget(QWidget):
|
|||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
self.file_transcriber = FileTranscriberObject(
|
||||
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
|
||||
file_path=self.file_path,
|
||||
language=self.selected_language, task=self.selected_task,
|
||||
output_file_path=output_file, output_format=self.selected_output_format,
|
||||
word_level_timings=self.enabled_word_level_timings,
|
||||
parent=self)
|
||||
self.file_transcriber.event_received.connect(
|
||||
self.on_transcriber_event)
|
||||
|
||||
self.file_transcriber.start()
|
||||
if use_whisper_cpp:
|
||||
self.file_transcriber = WhisperCppFileTranscriber(
|
||||
model_path=model_path, file_path=self.file_path,
|
||||
language=self.selected_language, task=self.selected_task,
|
||||
output_file_path=output_file, output_format=self.selected_output_format,
|
||||
word_level_timings=self.enabled_word_level_timings,
|
||||
)
|
||||
self.file_transcriber.signals.progress.connect(
|
||||
self.on_transcriber_progress)
|
||||
self.file_transcriber.signals.completed.connect(
|
||||
self.on_transcriber_complete)
|
||||
self.pool.start(self.file_transcriber)
|
||||
else:
|
||||
self.file_transcriber = FileTranscriberObject(
|
||||
model_path=model_path, file_path=self.file_path,
|
||||
language=self.selected_language, task=self.selected_task,
|
||||
output_file_path=output_file, output_format=self.selected_output_format,
|
||||
word_level_timings=self.enabled_word_level_timings,
|
||||
parent=self)
|
||||
self.file_transcriber.event_received.connect(
|
||||
self.on_transcriber_event)
|
||||
self.file_transcriber.start()
|
||||
|
||||
self.model_loader = ModelLoader(
|
||||
name=model_name, use_whisper_cpp=use_whisper_cpp)
|
||||
|
@ -522,22 +530,28 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
def on_transcriber_event(self, event: FileTranscriber.Event):
|
||||
if isinstance(event, FileTranscriber.ProgressEvent):
|
||||
current_size = event.current_value
|
||||
total_size = event.max_value
|
||||
|
||||
# Create a dialog
|
||||
if self.transcriber_progress_dialog is None:
|
||||
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
||||
file_path=self.file_path, total_size=total_size, parent=self)
|
||||
self.transcriber_progress_dialog.canceled.connect(
|
||||
self.on_cancel_transcriber_progress_dialog)
|
||||
|
||||
# Update the progress of the dialog unless it has
|
||||
# been canceled before this progress update arrived
|
||||
if self.transcriber_progress_dialog is not None:
|
||||
self.transcriber_progress_dialog.update_progress(current_size)
|
||||
self.on_transcriber_progress(
|
||||
(event.current_value, event.max_value))
|
||||
elif isinstance(event, FileTranscriber.CompletedTranscriptionEvent):
|
||||
self.reset_transcription()
|
||||
self.on_transcriber_complete()
|
||||
|
||||
def on_transcriber_progress(self, progress: Tuple[int, int]):
|
||||
(current_size, total_size) = progress
|
||||
|
||||
# Create a dialog
|
||||
if self.transcriber_progress_dialog is None:
|
||||
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
||||
file_path=self.file_path, total_size=total_size, parent=self)
|
||||
self.transcriber_progress_dialog.canceled.connect(
|
||||
self.on_cancel_transcriber_progress_dialog)
|
||||
|
||||
# Update the progress of the dialog unless it has
|
||||
# been canceled before this progress update arrived
|
||||
if self.transcriber_progress_dialog is not None:
|
||||
self.transcriber_progress_dialog.update_progress(current_size)
|
||||
|
||||
def on_transcriber_complete(self):
|
||||
self.reset_transcription()
|
||||
|
||||
def on_cancel_transcriber_progress_dialog(self):
|
||||
if self.file_transcriber is not None:
|
||||
|
@ -547,6 +561,7 @@ class FileTranscriberWidget(QWidget):
|
|||
def reset_transcription(self):
|
||||
self.run_button.setDisabled(False)
|
||||
if self.transcriber_progress_dialog is not None:
|
||||
self.transcriber_progress_dialog.close()
|
||||
self.transcriber_progress_dialog = None
|
||||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
|
|
|
@ -4,21 +4,25 @@ import logging
|
|||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.connection import Connection
|
||||
from threading import Thread
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import sounddevice
|
||||
import stable_whisper
|
||||
import whisper
|
||||
from PyQt6.QtCore import QObject, QProcess, QRunnable, pyqtSignal, pyqtSlot
|
||||
from sounddevice import PortAudioError
|
||||
|
||||
from .conn import pipe_stderr, pipe_stdout
|
||||
from .conn import pipe_stderr
|
||||
from .whispr import (Segment, Task, WhisperCpp, read_progress,
|
||||
whisper_cpp_params)
|
||||
|
||||
|
@ -171,27 +175,27 @@ def to_timestamp(ms: float) -> str:
|
|||
|
||||
|
||||
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
|
||||
file = open(path, 'w', encoding='utf-8')
|
||||
logging.debug(
|
||||
'Writing transcription output, path = %s, output format = %s, number of segments = %s', path, output_format, len(segments))
|
||||
|
||||
if output_format == OutputFormat.TXT:
|
||||
for segment in segments:
|
||||
file.write(segment.text)
|
||||
with open(path, 'w', encoding='utf-8') as file:
|
||||
if output_format == OutputFormat.TXT:
|
||||
for segment in segments:
|
||||
file.write(segment.text + ' ')
|
||||
|
||||
elif output_format == OutputFormat.VTT:
|
||||
file.write('WEBVTT\n\n')
|
||||
for segment in segments:
|
||||
file.write(
|
||||
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
|
||||
file.write(f'{segment.text}\n\n')
|
||||
elif output_format == OutputFormat.VTT:
|
||||
file.write('WEBVTT\n\n')
|
||||
for segment in segments:
|
||||
file.write(
|
||||
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
|
||||
file.write(f'{segment.text}\n\n')
|
||||
|
||||
elif output_format == OutputFormat.SRT:
|
||||
for (i, segment) in enumerate(segments):
|
||||
file.write(f'{i+1}\n')
|
||||
file.write(
|
||||
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
|
||||
file.write(f'{segment.text}\n\n')
|
||||
|
||||
file.close()
|
||||
elif output_format == OutputFormat.SRT:
|
||||
for (i, segment) in enumerate(segments):
|
||||
file.write(f'{i+1}\n')
|
||||
file.write(
|
||||
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
|
||||
file.write(f'{segment.text}\n\n')
|
||||
|
||||
if should_open:
|
||||
try:
|
||||
|
@ -201,6 +205,116 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
|
|||
subprocess.call([opener, path])
|
||||
|
||||
|
||||
class WhisperCppFileTranscriber(QRunnable):
    """Transcribes an audio file by running the `whisper_cpp` executable in a QProcess.

    The input file is first converted to a mono 16-bit PCM wav file with
    ffmpeg. Segments are parsed from the process's standard output, and the
    audio duration is parsed from its standard error. Progress and completion
    are reported through `signals`.
    """

    class Signals(QObject):
        progress = pyqtSignal(tuple)  # (current, total)
        completed = pyqtSignal(bool)
        error = pyqtSignal(str)

    signals: Signals
    duration_audio_ms = sys.maxsize  # max int
    # Declared here for type checkers only; the list itself is created per
    # instance in __init__ so that transcriptions never share segments.
    segments: List[Segment]

    def __init__(
            self,
            model_path: str, language: Optional[str], task: Task, file_path: str,
            output_file_path: str, output_format: OutputFormat,
            word_level_timings: bool, open_file_on_complete=True,
    ) -> None:
        super(WhisperCppFileTranscriber, self).__init__()

        self.file_path = file_path
        self.output_file_path = output_file_path
        self.language = language
        self.task = task
        self.open_file_on_complete = open_file_on_complete
        self.output_format = output_format
        self.word_level_timings = word_level_timings
        self.model_path = model_path
        self.signals = self.Signals()
        self.segments = []
        # Path of the temporary wav file created by run(), if any
        self.wav_file: Optional[str] = None

        self.process = QProcess()
        self.process.readyReadStandardError.connect(self.read_std_err)
        self.process.readyReadStandardOutput.connect(self.read_std_out)
        self.process.finished.connect(self.on_process_finished)

    @pyqtSlot()
    def run(self):
        """Convert the input file to wav and start the whisper_cpp process."""
        logging.debug(
            'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s',
            self.file_path, self.language, self.task, self.output_file_path, self.output_format, self.model_path)

        # mkstemp creates the file atomically (mktemp only returns a name and
        # is race-prone); ffmpeg re-opens the path itself, so close our handle.
        file_descriptor, wav_file = tempfile.mkstemp(suffix='.wav')
        os.close(file_descriptor)
        self.wav_file = wav_file

        try:
            (
                ffmpeg.input(self.file_path)
                .output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
                # The file already exists, so ffmpeg must be told to overwrite it
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True,
                     capture_stderr=True, overwrite_output=True)
            )
        except Exception:
            self._remove_wav_file()
            raise

        args = [
            '--language', self.language if self.language is not None else 'en',
            '--max-len', '1' if self.word_level_timings else '0',
            '--model', self.model_path,
            '--verbose'
        ]
        if self.task == Task.TRANSLATE:
            args.append('--translate')
        args.append(wav_file)

        logging.debug('Running whisper_cpp process, args = %s', args)

        self.process.start('./whisper_cpp', args)

    def on_process_finished(self):
        """Write the output file if the process exited normally, then signal completion."""
        status = self.process.exitStatus()
        logging.debug('whisper_cpp process completed with status = %s', status)

        # The intermediate wav file is no longer needed once the process has ended
        self._remove_wav_file()

        if status == QProcess.ExitStatus.NormalExit:
            self.signals.progress.emit(
                (self.duration_audio_ms, self.duration_audio_ms))
            write_output(
                self.output_file_path, self.segments, self.open_file_on_complete, self.output_format)

        self.signals.completed.emit(True)

    def stop(self):
        """Terminate the whisper_cpp process if it is starting or running."""
        if self.process.state() in (QProcess.ProcessState.Starting, QProcess.ProcessState.Running):
            self.process.terminate()

    def read_std_out(self):
        """Parse `[start --> end] text` lines from stdout into segments."""
        output = self.process.readAllStandardOutput().data().decode('UTF-8').strip()
        logging.debug('whisper_cpp (stdout): %s', output)

        for line in output.splitlines():
            line = line.strip()
            # Split on the bracket that closes the timings instead of on
            # whitespace, since the segment text itself contains spaces.
            timings, closing_bracket, text = line.partition(']')
            if not (line.startswith('[') and closing_bracket):
                continue

            try:
                start, end = self.parse_timings(timings + closing_bracket)
            except ValueError:
                logging.debug('Skipping unparseable whisper_cpp line: %s', line)
                continue

            self.segments.append(Segment(start, end, text.strip()))
            self.signals.progress.emit((end, self.duration_audio_ms))

    def parse_timings(self, timings: str) -> Tuple[int, int]:
        """Convert `[HH:MM:SS.mmm --> HH:MM:SS.mmm]` into (start, end) in milliseconds."""
        # Drop the surrounding square brackets
        start, end = timings[1:-1].split(' --> ')
        return self.parse_timestamp(start), self.parse_timestamp(end)

    def parse_timestamp(self, timestamp: str) -> int:
        """Convert `HH:MM:SS.mmm` into milliseconds."""
        hrs, mins, secs_ms = timestamp.split(':')
        secs, ms = secs_ms.split('.')
        return int(hrs)*60*60*1000 + int(mins)*60*1000 + int(secs)*1000 + int(ms)

    def read_std_err(self):
        """Read the audio duration from the `main: processing` line on stderr."""
        output = self.process.readAllStandardError().data().decode('UTF-8').strip()
        logging.debug('whisper_cpp (stderr): %s', output)

        for line in output.splitlines():
            if line.startswith('main: processing'):
                match = re.search(r'samples, (.*) sec', line)
                if match is not None:
                    self.duration_audio_ms = round(float(match.group(1))*1000)

    def _remove_wav_file(self):
        """Delete the temporary wav file created by run(), ignoring a missing file."""
        if self.wav_file is None:
            return
        try:
            os.remove(self.wav_file)
        except OSError:
            logging.debug('Unable to remove temporary file %s', self.wav_file)
        self.wav_file = None
|
||||
|
||||
|
||||
class FileTranscriber:
|
||||
"""FileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file using the default program for opening txt files."""
|
||||
|
||||
|
@ -222,13 +336,10 @@ class FileTranscriber:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str, use_whisper_cpp: bool,
|
||||
language: Optional[str], task: Task, file_path: str,
|
||||
model_path: str, language: Optional[str], task: Task, file_path: str,
|
||||
output_file_path: str, output_format: OutputFormat,
|
||||
word_level_timings: bool,
|
||||
event_callback: Callable[[Event], None] = lambda *_: None,
|
||||
on_download_model_chunk: Callable[[
|
||||
int, int], None] = lambda *_: None,
|
||||
open_file_on_complete=True) -> None:
|
||||
self.file_path = file_path
|
||||
self.output_file_path = output_file_path
|
||||
|
@ -238,8 +349,6 @@ class FileTranscriber:
|
|||
self.output_format = output_format
|
||||
self.word_level_timings = word_level_timings
|
||||
self.model_path = model_path
|
||||
self.use_whisper_cpp = use_whisper_cpp
|
||||
self.on_download_model_chunk = on_download_model_chunk
|
||||
self.event_callback = event_callback
|
||||
|
||||
def start(self):
|
||||
|
@ -258,31 +367,19 @@ class FileTranscriber:
|
|||
return
|
||||
|
||||
self.event_callback(self.ProgressEvent(0, 100))
|
||||
if self.use_whisper_cpp:
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=transcribe_whisper_cpp,
|
||||
args=(
|
||||
send_pipe, self.model_path, self.file_path,
|
||||
self.output_file_path, self.open_file_on_complete,
|
||||
self.output_format,
|
||||
self.language if self.language is not None else 'en',
|
||||
self.task, True, True,
|
||||
self.word_level_timings
|
||||
))
|
||||
else:
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=transcribe_whisper,
|
||||
args=(
|
||||
send_pipe, self.model_path, self.file_path,
|
||||
self.language, self.task, self.output_file_path,
|
||||
self.open_file_on_complete, self.output_format,
|
||||
self.word_level_timings
|
||||
))
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=transcribe_whisper,
|
||||
args=(
|
||||
send_pipe, self.model_path, self.file_path,
|
||||
self.language, self.task, self.output_file_path,
|
||||
self.open_file_on_complete, self.output_format,
|
||||
self.word_level_timings
|
||||
))
|
||||
|
||||
self.current_process.start()
|
||||
|
||||
thread = Thread(target=read_progress, args=(
|
||||
recv_pipe, self.use_whisper_cpp,
|
||||
recv_pipe,
|
||||
lambda current_value, max_value: self.event_callback(self.ProgressEvent(current_value, max_value))))
|
||||
thread.start()
|
||||
|
||||
|
@ -351,20 +448,3 @@ def transcribe_whisper(
|
|||
|
||||
write_output(output_file_path, list(
|
||||
segments), open_file_on_complete, output_format)
|
||||
|
||||
|
||||
def transcribe_whisper_cpp(
|
||||
stderr_conn: Connection, model_path: str, audio: typing.Union[np.ndarray, str],
|
||||
output_file_path: str, open_file_on_complete: bool, output_format: OutputFormat,
|
||||
language: str, task: Task, print_realtime: bool, print_progress: bool,
|
||||
word_level_timings: bool):
|
||||
# TODO: capturing output does not work because ctypes functions
|
||||
# See: https://stackoverflow.com/questions/9488560/capturing-print-output-from-shared-library-called-from-python-with-ctypes-module
|
||||
with pipe_stdout(stderr_conn), pipe_stderr(stderr_conn):
|
||||
model = WhisperCpp(model_path)
|
||||
params = whisper_cpp_params(
|
||||
language, task, word_level_timings, print_realtime, print_progress)
|
||||
result = model.transcribe(audio=audio, params=params)
|
||||
segments: List[Segment] = result.get('segments')
|
||||
write_output(
|
||||
output_file_path, segments, open_file_on_complete, output_format)
|
||||
|
|
|
@ -26,8 +26,8 @@ class Stopped(Exception):
|
|||
|
||||
@dataclass
|
||||
class Segment:
|
||||
start: float
|
||||
end: float
|
||||
start: int # start time in ms
|
||||
end: int # end time in ms
|
||||
text: str
|
||||
|
||||
|
||||
|
@ -95,33 +95,12 @@ def tqdm_progress(line: str):
|
|||
return int(percent_progress)
|
||||
|
||||
|
||||
def whisper_cpp_progress(lines: str):
|
||||
"""Extracts the progress of a whisper.cpp transcription.
|
||||
|
||||
The log lines have the following format:
|
||||
whisper_full: progress = 20%\n
|
||||
"""
|
||||
|
||||
# Example log line: "whisper_full: progress = 20%"
|
||||
progress_lines = list(filter(lambda line: line.startswith(
|
||||
'whisper_full: progress'), lines.split('\n')))
|
||||
if len(progress_lines) == 0:
|
||||
raise ValueError('No lines match whisper.cpp progress format')
|
||||
last_word = progress_lines[-1].split(' ')[-1]
|
||||
return min(int(last_word[:-1]), 100)
|
||||
|
||||
|
||||
def read_progress(
|
||||
pipe: Connection, use_whisper_cpp: bool,
|
||||
progress_callback: Callable[[int, int], None]):
|
||||
def read_progress(pipe: Connection, progress_callback: Callable[[int, int], None]):
|
||||
while pipe.closed is False:
|
||||
try:
|
||||
recv = pipe.recv().strip()
|
||||
if recv:
|
||||
if use_whisper_cpp:
|
||||
progress = whisper_cpp_progress(recv)
|
||||
else:
|
||||
progress = tqdm_progress(recv)
|
||||
progress = tqdm_progress(recv)
|
||||
progress_callback(progress, 100)
|
||||
except ValueError:
|
||||
pass
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from buzz.model_loader import ModelLoader
|
||||
from buzz.transcriber import (FileTranscriber, OutputFormat,
|
||||
RecordingTranscriber, to_timestamp)
|
||||
RecordingTranscriber, WhisperCppFileTranscriber,
|
||||
to_timestamp)
|
||||
from buzz.whispr import Task
|
||||
|
||||
|
||||
|
@ -33,6 +36,38 @@ class TestRecordingTranscriber:
|
|||
assert transcriber is not None
|
||||
|
||||
|
||||
class TestWhisperCppFileTranscriber:
    """Runs WhisperCppFileTranscriber end to end on the French sample audio."""

    @pytest.mark.parametrize(
        'task,output_text',
        [
            (Task.TRANSCRIBE, 'Bienvenue dans Passe-Relle, un podcast'),
            (Task.TRANSLATE, 'Welcome to Passe-Relle, a podcast'),
        ])
    def test_transcribe(self, qtbot, tmp_path: pathlib.Path, task: Task, output_text: str):
        output_file_path = tmp_path / 'whisper_cpp.txt'
        if os.path.exists(output_file_path):
            os.remove(output_file_path)

        model_path = get_model_path('tiny', True)
        transcriber = WhisperCppFileTranscriber(
            model_path=model_path, language='fr',
            task=task, file_path='testdata/whisper-french.mp3',
            output_file_path=output_file_path.as_posix(), output_format=OutputFormat.TXT,
            open_file_on_complete=False,
            word_level_timings=False)
        mock_progress = Mock()
        # The transcription finishes asynchronously, so wait for the completed signal
        with qtbot.waitSignal(transcriber.signals.completed, timeout=10*60*1000):
            transcriber.signals.progress.connect(mock_progress)
            transcriber.run()

        assert os.path.isfile(output_file_path)

        # read_text closes the file, unlike a bare open() left to the garbage collector
        assert output_text in output_file_path.read_text(encoding='utf-8')

        mock_progress.assert_called()
|
||||
|
||||
|
||||
class TestFileTranscriber:
|
||||
def test_default_output_file(self):
|
||||
srt = FileTranscriber.get_default_output_file_path(
|
||||
|
@ -64,7 +99,7 @@ class TestFileTranscriber:
|
|||
|
||||
model_path = get_model_path('tiny', False)
|
||||
transcriber = FileTranscriber(
|
||||
model_path=model_path, use_whisper_cpp=False, language='fr',
|
||||
model_path=model_path, language='fr',
|
||||
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path.as_posix(), output_format=output_format,
|
||||
open_file_on_complete=False, event_callback=event_callback,
|
||||
|
@ -97,7 +132,7 @@ class TestFileTranscriber:
|
|||
|
||||
model_path = get_model_path('tiny', False)
|
||||
transcriber = FileTranscriber(
|
||||
model_path=model_path, use_whisper_cpp=False, language='fr',
|
||||
model_path=model_path, language='fr',
|
||||
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False, event_callback=event_callback,
|
||||
|
@ -109,32 +144,6 @@ class TestFileTranscriber:
|
|||
# Assert that file was not created
|
||||
assert os.path.isfile(output_file_path) is False
|
||||
|
||||
def test_transcribe_whisper_cpp(self):
|
||||
output_file_path = os.path.join(
|
||||
tempfile.gettempdir(), 'whisper_cpp.txt')
|
||||
if os.path.exists(output_file_path):
|
||||
os.remove(output_file_path)
|
||||
|
||||
events = []
|
||||
|
||||
def event_callback(event: FileTranscriber.Event):
|
||||
events.append(event)
|
||||
|
||||
model_path = get_model_path('tiny', True)
|
||||
transcriber = FileTranscriber(
|
||||
model_path=model_path, use_whisper_cpp=True, language='fr',
|
||||
task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False, event_callback=event_callback,
|
||||
word_level_timings=False)
|
||||
transcriber.start()
|
||||
transcriber.join()
|
||||
|
||||
assert os.path.isfile(output_file_path)
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert 'Bienvenue dans Passe-Relle, un podcast' in output_file.read()
|
||||
|
||||
|
||||
class TestToTimestamp:
|
||||
def test_to_timestamp(self):
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 745f999d2dc32f2caeceb8e45d555ccd41e07669
|
||||
Subproject commit 9ab012f37aa18ce9504ce2343e3c99ac37778498
|
Loading…
Reference in a new issue