Add transcription viewer (#246)

This commit is contained in:
Chidi Williams 2022-12-16 01:23:29 +00:00 committed by GitHub
parent cccc8cea00
commit 7395de653f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 324 additions and 170 deletions

View file

@ -7,4 +7,4 @@ omit =
directory = coverage/html directory = coverage/html
[report] [report]
fail_under = 77 fail_under = 76

View file

@ -1,3 +1,4 @@
from dataclasses import dataclass
import enum import enum
import logging import logging
import os import os
@ -9,15 +10,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import humanize import humanize
import sounddevice import sounddevice
from PyQt6 import QtGui from PyQt6 import QtGui
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, pyqtSlot,
QThreadPool, QTimer, QUrl, pyqtSignal) QThreadPool, QTimer, QUrl, pyqtSignal)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon, from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor, QValidator) QKeySequence, QPixmap, QTextCursor, QValidator)
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog, from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QDialogButtonBox, QFileDialog, QGridLayout, QDialogButtonBox, QFileDialog, QGridLayout, QToolButton,
QHBoxLayout, QLabel, QLayout, QLineEdit, QLabel, QLayout, QLineEdit,
QMainWindow, QMessageBox, QPlainTextEdit, QMainWindow, QMessageBox, QPlainTextEdit,
QProgressDialog, QPushButton, QVBoxLayout, QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
QWidget) QWidget)
from requests import get from requests import get
from whisper import tokenizer from whisper import tokenizer
@ -25,10 +26,10 @@ from whisper import tokenizer
from .__version__ import VERSION from .__version__ import VERSION
from .model_loader import ModelLoader from .model_loader import ModelLoader
from .transcriber import (DEFAULT_WHISPER_TEMPERATURE, LOADED_WHISPER_DLL, from .transcriber import (DEFAULT_WHISPER_TEMPERATURE, LOADED_WHISPER_DLL,
SUPPORTED_OUTPUT_FORMATS, OutputFormat, SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
RecordingTranscriber, Task, RecordingTranscriber, Segment, Task,
WhisperCppFileTranscriber, WhisperFileTranscriber, WhisperCppFileTranscriber, WhisperFileTranscriber,
get_default_output_file_path) get_default_output_file_path, segments_to_text, write_output)
APP_NAME = 'Buzz' APP_NAME = 'Buzz'
@ -174,19 +175,9 @@ class QualityComboBox(QComboBox):
class TextDisplayBox(QPlainTextEdit): class TextDisplayBox(QPlainTextEdit):
"""TextDisplayBox is a read-only textbox""" """TextDisplayBox is a read-only textbox"""
os_styles = {
'Darwin': '''QTextEdit {
border: 0;
}'''
}
def __init__(self, parent: Optional[QWidget], *args) -> None: def __init__(self, parent: Optional[QWidget], *args) -> None:
super().__init__(parent, *args) super().__init__(parent, *args)
self.setReadOnly(True) self.setReadOnly(True)
self.setPlaceholderText('Click Record to begin...')
self.setStyleSheet(
'''QTextEdit {
} %s''' % get_platform_styles(self.os_styles))
class RecordButton(QPushButton): class RecordButton(QPushButton):
@ -371,6 +362,8 @@ class FileTranscriberWidget(QWidget):
model_loader: Optional[ModelLoader] = None model_loader: Optional[ModelLoader] = None
transcriber_thread: Optional[QThread] = None transcriber_thread: Optional[QThread] = None
transcribed = pyqtSignal() transcribed = pyqtSignal()
transcription_options: FileTranscriptionOptions
is_transcribing = False
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None: def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
super().__init__(parent) super().__init__(parent)
@ -378,6 +371,8 @@ class FileTranscriberWidget(QWidget):
layout = QGridLayout(self) layout = QGridLayout(self)
self.settings = Settings(self) self.settings = Settings(self)
self.transcription_options = FileTranscriptionOptions(
file_path=file_path)
self.file_path = file_path self.file_path = file_path
@ -462,15 +457,6 @@ class FileTranscriberWidget(QWidget):
self.initial_prompt = initial_prompt self.initial_prompt = initial_prompt
def on_click_run(self): def on_click_run(self):
default_path = get_default_output_file_path(
task=self.selected_task, input_file_path=self.file_path,
output_format=self.selected_output_format)
(output_file, _) = QFileDialog.getSaveFileName(
self, 'Save File', default_path, f'Text files (*.{self.selected_output_format.value})')
if output_file == '':
return
use_whisper_cpp = self.settings.get_enable_ggml_inference( use_whisper_cpp = self.settings.get_enable_ggml_inference(
) and self.selected_language is not None ) and self.selected_language is not None
@ -482,20 +468,18 @@ class FileTranscriberWidget(QWidget):
self.model_loader = ModelLoader( self.model_loader = ModelLoader(
name=model_name, use_whisper_cpp=use_whisper_cpp) name=model_name, use_whisper_cpp=use_whisper_cpp)
self.transcription_options = FileTranscriptionOptions(
file_path=self.file_path, language=self.selected_language,
task=self.selected_task, word_level_timings=self.enabled_word_level_timings,
temperature=self.temperature, initial_prompt=self.initial_prompt
)
if use_whisper_cpp: if use_whisper_cpp:
self.file_transcriber = WhisperCppFileTranscriber( self.file_transcriber = WhisperCppFileTranscriber(
file_path=self.file_path, language=self.selected_language, task=self.selected_task, self.transcription_options)
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
)
else: else:
self.file_transcriber = WhisperFileTranscriber( self.file_transcriber = WhisperFileTranscriber(
file_path=self.file_path, language=self.selected_language, task=self.selected_task, self.transcription_options)
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
temperature=self.temperature,
initial_prompt=self.initial_prompt
)
self.model_loader.moveToThread(self.transcriber_thread) self.model_loader.moveToThread(self.transcriber_thread)
self.file_transcriber.moveToThread(self.transcriber_thread) self.file_transcriber.moveToThread(self.transcriber_thread)
@ -515,6 +499,7 @@ class FileTranscriberWidget(QWidget):
self.model_loader.finished.connect(self.model_loader.deleteLater) self.model_loader.finished.connect(self.model_loader.deleteLater)
# Run the file transcriber after the model loads # Run the file transcriber after the model loads
self.model_loader.finished.connect(self.on_model_loaded)
self.model_loader.finished.connect(self.file_transcriber.run) self.model_loader.finished.connect(self.file_transcriber.run)
self.file_transcriber.progress.connect( self.file_transcriber.progress.connect(
@ -527,6 +512,9 @@ class FileTranscriberWidget(QWidget):
self.transcriber_thread.start() self.transcriber_thread.start()
def on_model_loaded(self):
    """Record that the model has loaded and transcription is starting.

    Connected to ``model_loader.finished``. ``on_transcriber_progress``
    only creates or updates the progress dialog while this flag is set.
    """
    self.is_transcribing = True
def on_download_model_progress(self, progress: Tuple[int, int]): def on_download_model_progress(self, progress: Tuple[int, int]):
(current_size, total_size) = progress (current_size, total_size) = progress
@ -545,22 +533,26 @@ class FileTranscriberWidget(QWidget):
self.reset_transcriber_controls() self.reset_transcriber_controls()
def on_transcriber_progress(self, progress: Tuple[int, int]): def on_transcriber_progress(self, progress: Tuple[int, int]):
logging.debug('received progress = %s', progress)
(current_size, total_size) = progress (current_size, total_size) = progress
# Create a dialog if self.is_transcribing:
if self.transcriber_progress_dialog is None: # Create a dialog
self.transcriber_progress_dialog = TranscriberProgressDialog( if self.transcriber_progress_dialog is None:
file_path=self.file_path, total_size=total_size, parent=self) self.transcriber_progress_dialog = TranscriberProgressDialog(
self.transcriber_progress_dialog.canceled.connect( file_path=self.file_path, total_size=total_size, parent=self)
self.on_cancel_transcriber_progress_dialog) self.transcriber_progress_dialog.canceled.connect(
self.on_cancel_transcriber_progress_dialog)
else:
# Update the progress of the dialog unless it has
# been canceled before this progress update arrived
self.transcriber_progress_dialog.update_progress(current_size)
# Update the progress of the dialog unless it has @pyqtSlot(tuple)
# been canceled before this progress update arrived def on_transcriber_complete(self, result: Tuple[int, List[Segment]]):
if self.transcriber_progress_dialog is not None: exit_code, segments = result
self.transcriber_progress_dialog.update_progress(current_size)
self.is_transcribing = False
def on_transcriber_complete(self, exit_code: int):
if self.transcriber_progress_dialog is not None: if self.transcriber_progress_dialog is not None:
self.transcriber_progress_dialog.reset() self.transcriber_progress_dialog.reset()
if exit_code != 0: if exit_code != 0:
@ -569,6 +561,10 @@ class FileTranscriberWidget(QWidget):
self.reset_transcriber_controls() self.reset_transcriber_controls()
self.transcribed.emit() self.transcribed.emit()
TranscriptionViewerWidget(
transcription_options=self.transcription_options,
segments=segments, parent=self, flags=Qt.WindowType.Window).show()
def on_cancel_transcriber_progress_dialog(self): def on_cancel_transcriber_progress_dialog(self):
if self.file_transcriber is not None: if self.file_transcriber is not None:
self.file_transcriber.stop() self.file_transcriber.stop()
@ -590,6 +586,70 @@ class FileTranscriberWidget(QWidget):
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value self.enabled_word_level_timings = value == Qt.CheckState.Checked.value
class TranscriptionViewerWidget(QWidget):
    """Widget that displays a finished transcription and offers an export menu.

    The segments are rendered as timestamped text inside a read-only
    ``TextDisplayBox``. An "Export" button exposes one menu entry per
    ``OutputFormat``; picking one asks for a destination file and writes the
    segments to it.
    """
    segments: List[Segment]
    transcription_options: FileTranscriptionOptions

    def __init__(
            self, transcription_options: FileTranscriptionOptions, segments: List[Segment],
            parent: Optional['QWidget'] = None, flags: Qt.WindowType = Qt.WindowType.Widget,
    ) -> None:
        """Build the text view and the export button.

        Args:
            transcription_options: Options the transcription was run with; the
                file path is used for the window title and, together with the
                task, for the default export path.
            segments: Transcribed segments to display and export.
            parent: Optional parent widget.
            flags: Window flags forwarded to ``QWidget``.
        """
        super().__init__(parent, flags)

        self.transcription_options = transcription_options
        self.segments = segments

        short_path = get_short_file_path(transcription_options.file_path)
        self.setWindowTitle(f'Transcription - {short_path}')
        self.setMinimumWidth(500)
        self.setMinimumHeight(500)

        root_layout = QVBoxLayout(self)

        self.text_box = TextDisplayBox(self)
        self.text_box.setPlainText(segments_to_text(segments))
        root_layout.addWidget(self.text_box)

        # One menu entry per output format, labelled with the upper-cased
        # format value; on_menu_triggered maps that label back to the enum.
        export_menu = QMenu()
        export_menu.addActions([
            QAction(text=output_format.value.upper(), parent=self)
            for output_format in OutputFormat])
        export_menu.triggered.connect(self.on_menu_triggered)

        export_button = QPushButton(self)
        export_button.setText('Export')
        export_button.setMenu(export_menu)

        # The leading stretch pushes the button to the right edge of the row.
        button_row = QHBoxLayout()
        button_row.addStretch()
        button_row.addWidget(export_button)
        root_layout.addLayout(button_row)

        self.setLayout(root_layout)

    def on_menu_triggered(self, action: QAction):
        """Export the segments in the format named by the triggered action.

        Args:
            action: The menu action that was triggered; its text is looked up
                as an ``OutputFormat`` member name.
        """
        output_format = OutputFormat[action.text()]
        options = self.transcription_options

        suggested_path = get_default_output_file_path(
            task=options.task, input_file_path=options.file_path,
            output_format=output_format)
        file_filter = f'Text files (*.{output_format.value})'
        output_file_path, _ = QFileDialog.getSaveFileName(
            self, 'Save File', suggested_path, file_filter)

        # An empty path means no file was chosen, so there is nothing to write.
        if output_file_path != '':
            write_output(path=output_file_path, segments=self.segments,
                         should_open=True, output_format=output_format)
class Settings(QSettings): class Settings(QSettings):
_ENABLE_GGML_INFERENCE = 'enable_ggml_inference' _ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
@ -686,6 +746,7 @@ class RecordingTranscriberWidget(QWidget):
self.open_advanced_settings) self.open_advanced_settings)
self.text_box = TextDisplayBox(self) self.text_box = TextDisplayBox(self)
self.text_box.setPlaceholderText('Click Record to begin...')
widgets = [ widgets = [
((0, 5, FormLabel('Task:', self)), (5, 7, self.tasks_combo_box)), ((0, 5, FormLabel('Task:', self)), (5, 7, self.tasks_combo_box)),

View file

@ -1,10 +1,12 @@
import ctypes import ctypes
import datetime import datetime
import enum import enum
import json
import logging import logging
import multiprocessing import multiprocessing
import os import os
import platform import platform
import queue
import re import re
import subprocess import subprocess
import sys import sys
@ -195,28 +197,33 @@ class OutputFormat(enum.Enum):
VTT = 'vtt' VTT = 'vtt'
@dataclass
class FileTranscriptionOptions:
    """Settings for transcribing a single audio file.

    Instances are handed to ``WhisperCppFileTranscriber`` and
    ``WhisperFileTranscriber``, which copy the fields they use. Only
    ``file_path`` is required; every other field has a default.
    """
    # Path of the input audio file.
    file_path: str
    # Language code such as 'fr'. None presumably means "auto-detect" -
    # TODO confirm against the transcriber implementations.
    language: Optional[str] = None
    # Whisper task to perform.
    task: Task = Task.TRANSCRIBE
    # Whether segments should carry word-level timings.
    word_level_timings: bool = False
    # Temperature value(s); read only by WhisperFileTranscriber.
    temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
    # Initial prompt text; read only by WhisperFileTranscriber.
    initial_prompt: str = ''
class WhisperCppFileTranscriber(QObject): class WhisperCppFileTranscriber(QObject):
progress = pyqtSignal(tuple) # (current, total) progress = pyqtSignal(tuple) # (current, total)
completed = pyqtSignal(int) completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
error = pyqtSignal(str) error = pyqtSignal(str)
duration_audio_ms = sys.maxsize # max int duration_audio_ms = sys.maxsize # max int
segments: List[Segment] segments: List[Segment]
running = False running = False
def __init__( def __init__(
self, language: Optional[str], task: Task, file_path: str, self, transcription_options: FileTranscriptionOptions,
output_file_path: str, output_format: OutputFormat,
word_level_timings: bool, open_file_on_complete=True,
parent: Optional['QObject'] = None) -> None: parent: Optional['QObject'] = None) -> None:
super().__init__(parent) super().__init__(parent)
self.file_path = file_path self.file_path = transcription_options.file_path
self.output_file_path = output_file_path self.language = transcription_options.language
self.language = language self.task = transcription_options.task
self.task = task self.word_level_timings = transcription_options.word_level_timings
self.open_file_on_complete = open_file_on_complete
self.output_format = output_format
self.word_level_timings = word_level_timings
self.segments = [] self.segments = []
self.process = QProcess(self) self.process = QProcess(self)
@ -228,8 +235,8 @@ class WhisperCppFileTranscriber(QObject):
self.running = True self.running = True
logging.debug( logging.debug(
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s', 'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s',
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path) self.file_path, self.language, self.task, model_path)
wav_file = tempfile.mktemp()+'.wav' wav_file = tempfile.mktemp()+'.wav'
( (
@ -258,10 +265,8 @@ class WhisperCppFileTranscriber(QObject):
if status == QProcess.ExitStatus.NormalExit: if status == QProcess.ExitStatus.NormalExit:
self.progress.emit( self.progress.emit(
(self.duration_audio_ms, self.duration_audio_ms)) (self.duration_audio_ms, self.duration_audio_ms))
write_output(
self.output_file_path, self.segments, self.open_file_on_complete, self.output_format)
self.completed.emit(status == self.process.exitCode()) self.completed.emit((self.process.exitCode(), self.segments))
self.running = False self.running = False
def stop(self): def stop(self):
@ -282,7 +287,7 @@ class WhisperCppFileTranscriber(QObject):
segment = Segment(start, end, text.strip()) segment = Segment(start, end, text.strip())
self.segments.append(segment) self.segments.append(segment)
self.progress.emit((end, self.duration_audio_ms)) self.progress.emit((end, self.duration_audio_ms))
except UnicodeDecodeError: except (UnicodeDecodeError, ValueError):
pass pass
def parse_timings(self, timings: str) -> Tuple[int, int]: def parse_timings(self, timings: str) -> Tuple[int, int]:
@ -315,44 +320,39 @@ class WhisperFileTranscriber(QObject):
current_process: multiprocessing.Process current_process: multiprocessing.Process
progress = pyqtSignal(tuple) # (current, total) progress = pyqtSignal(tuple) # (current, total)
completed = pyqtSignal(int) completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
error = pyqtSignal(str) error = pyqtSignal(str)
running = False running = False
read_line_thread: Optional[Thread] = None read_line_thread: Optional[Thread] = None
def __init__(self, def __init__(
language: Optional[str], task: Task, file_path: str, self, transcription_options: FileTranscriptionOptions,
output_file_path: str, output_format: OutputFormat, parent: Optional['QObject'] = None) -> None:
word_level_timings: bool, temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE,
initial_prompt: str = '', open_file_on_complete=True, parent: Optional['QObject'] = None) -> None:
super().__init__(parent) super().__init__(parent)
self.file_path = file_path self.file_path = transcription_options.file_path
self.output_file_path = output_file_path self.language = transcription_options.language
self.language = language self.task = transcription_options.task
self.task = task self.word_level_timings = transcription_options.word_level_timings
self.open_file_on_complete = open_file_on_complete self.temperature = transcription_options.temperature
self.output_format = output_format self.initial_prompt = transcription_options.initial_prompt
self.word_level_timings = word_level_timings self.segments = []
self.temperature = temperature
self.initial_prompt = initial_prompt
@pyqtSlot(str) @pyqtSlot(str)
def run(self, model_path: str): def run(self, model_path: str):
self.running = True self.running = True
time_started = datetime.datetime.now() time_started = datetime.datetime.now()
logging.debug( logging.debug(
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model path = %s, temperature = %s, initial prompt length = %s', 'Starting whisper file transcription, file path = %s, language = %s, task = %s, model path = %s, temperature = %s, initial prompt length = %s',
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path, self.temperature, len(self.initial_prompt)) self.file_path, self.language, self.task, model_path, self.temperature, len(self.initial_prompt))
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False) recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
self.current_process = multiprocessing.Process( self.current_process = multiprocessing.Process(
target=transcribe_whisper, target=transcribe_whisper,
args=( args=(
send_pipe, model_path, self.file_path, send_pipe, model_path, self.file_path,
self.language, self.task, self.output_file_path, self.language, self.task, self.word_level_timings,
self.open_file_on_complete, self.output_format,
self.word_level_timings,
self.temperature, self.initial_prompt self.temperature, self.initial_prompt
)) ))
@ -373,7 +373,9 @@ class WhisperFileTranscriber(QObject):
self.read_line_thread.join() self.read_line_thread.join()
self.completed.emit(self.current_process.exitcode) if self.current_process.exitcode != 0:
self.completed.emit((self.current_process.exitcode, []))
self.running = False self.running = False
def stop(self): def stop(self):
@ -381,6 +383,17 @@ class WhisperFileTranscriber(QObject):
self.current_process.terminate() self.current_process.terminate()
def on_whisper_stdout(self, line: str): def on_whisper_stdout(self, line: str):
if line.startswith('segments = '):
segments_dict = json.loads(line[11:])
segments = [Segment(
start=segment.get('start'),
end=segment.get('end'),
text=segment.get('text'),
) for segment in segments_dict]
self.current_process.join()
self.completed.emit((self.current_process.exitcode, segments))
return
try: try:
progress = int(line.split('|')[0].strip().strip('%')) progress = int(line.split('|')[0].strip().strip('%'))
self.progress.emit((progress, 100)) self.progress.emit((progress, 100))
@ -398,8 +411,7 @@ class WhisperFileTranscriber(QObject):
def transcribe_whisper( def transcribe_whisper(
stderr_conn: Connection, model_path: str, file_path: str, stderr_conn: Connection, model_path: str, file_path: str,
language: Optional[str], task: Task, output_file_path: str, language: Optional[str], task: Task,
open_file_on_complete: bool, output_format: OutputFormat,
word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str): word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
with pipe_stderr(stderr_conn): with pipe_stderr(stderr_conn):
model = whisper.load_model(model_path) model = whisper.load_model(model_path)
@ -418,15 +430,15 @@ def transcribe_whisper(
whisper_segments = stable_whisper.group_word_timestamps( whisper_segments = stable_whisper.group_word_timestamps(
result) if word_level_timings else result.get('segments') result) if word_level_timings else result.get('segments')
segments = map( segments = [
lambda segment: Segment( Segment(
start=segment.get('start')*1000, # s to ms start=int(segment.get('start')*1000),
end=segment.get('end')*1000, # s to ms end=int(segment.get('end')*1000),
text=segment.get('text')), text=segment.get('text'),
whisper_segments) ) for segment in whisper_segments]
segments_json = json.dumps(
write_output(output_file_path, list( segments, ensure_ascii=True, default=vars)
segments), open_file_on_complete, output_format) sys.stderr.write(f'segments = {segments_json}\n')
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat): def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
@ -435,8 +447,11 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
with open(path, 'w', encoding='utf-8') as file: with open(path, 'w', encoding='utf-8') as file:
if output_format == OutputFormat.TXT: if output_format == OutputFormat.TXT:
for segment in segments: for (i, segment) in enumerate(segments):
file.write(segment.text + ' ') file.write(segment.text)
if i < len(segments)-1:
file.write(' ')
file.write('\n')
elif output_format == OutputFormat.VTT: elif output_format == OutputFormat.VTT:
file.write('WEBVTT\n\n') file.write('WEBVTT\n\n')
@ -460,6 +475,16 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
subprocess.call([opener, path]) subprocess.call([opener, path])
def segments_to_text(segments: List[Segment]) -> str:
    """Render segments as human-readable, timestamped text.

    Each segment becomes a ``<start> --> <end>`` line (formatted with
    ``to_timestamp``) followed by the segment text. Segments are separated by
    one blank line, with no trailing separator after the last one.

    Args:
        segments: Segments to render, in display order.

    Returns:
        The rendered text, or an empty string when ``segments`` is empty.
    """
    # str.join places the separator only between items, which replaces the
    # manual "is this the last segment?" check and repeated string +=.
    return '\n\n'.join(
        f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n'
        f'{segment.text}'
        for segment in segments)
def to_timestamp(ms: float) -> str: def to_timestamp(ms: float) -> str:
hr = int(ms / (1000*60*60)) hr = int(ms / (1000*60*60))
ms = ms - hr * (1000*60*60) ms = ms - hr * (1000*60*60)

View file

@ -1,21 +1,24 @@
import logging import logging
import os import os
import pathlib import pathlib
from typing import Any, Callable
import platform
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
import sounddevice import sounddevice
from PyQt6.QtCore import Qt, QCoreApplication from PyQt6.QtCore import Qt, QCoreApplication
from PyQt6.QtGui import (QValidator) from PyQt6.QtGui import (QValidator)
from PyQt6.QtWidgets import (QPushButton)
from pytestqt.qtbot import QtBot from pytestqt.qtbot import QtBot
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application, from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
AudioDevicesComboBox, DownloadModelProgressDialog, AudioDevicesComboBox, DownloadModelProgressDialog,
FileTranscriberWidget, LanguagesComboBox, MainWindow, FileTranscriberWidget, LanguagesComboBox, MainWindow,
OutputFormatsComboBox, Quality, QualityComboBox, OutputFormatsComboBox, Quality, QualityComboBox,
Settings, TemperatureValidator, Settings, TemperatureValidator, TextDisplayBox,
TranscriberProgressDialog) TranscriberProgressDialog, TranscriptionViewerWidget)
from buzz.transcriber import OutputFormat from buzz.transcriber import FileTranscriptionOptions, OutputFormat, Segment
class TestApplication: class TestApplication:
@ -142,25 +145,29 @@ class TestTranscriberProgressDialog:
class TestDownloadModelProgressDialog: class TestDownloadModelProgressDialog:
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None) def test_should_show_dialog(self, qtbot: QtBot):
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
qtbot.add_widget(dialog)
assert dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
def test_should_show_dialog(self): def test_should_update_label_on_progress(self, qtbot: QtBot):
assert self.dialog.labelText() == 'Downloading resources (0%, unknown time remaining)' dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
qtbot.add_widget(dialog)
dialog.setValue(0)
def test_should_update_label_on_progress(self): dialog.setValue(12345)
self.dialog.setValue(0) assert dialog.labelText().startswith(
self.dialog.setValue(12345)
assert self.dialog.labelText().startswith(
'Downloading resources (1.00%') 'Downloading resources (1.00%')
self.dialog.setValue(123456) dialog.setValue(123456)
assert self.dialog.labelText().startswith( assert dialog.labelText().startswith(
'Downloading resources (10.00%') 'Downloading resources (10.00%')
# Other windows should not be processing while models are being downloaded # Other windows should not be processing while models are being downloaded
def test_should_be_an_application_modal(self): def test_should_be_an_application_modal(self, qtbot: QtBot):
assert self.dialog.windowModality() == Qt.WindowModality.ApplicationModal dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
qtbot.add_widget(dialog)
assert dialog.windowModality() == Qt.WindowModality.ApplicationModal
class TestFormatsComboBox: class TestFormatsComboBox:
@ -177,21 +184,32 @@ class TestMainWindow:
assert main_window is not None assert main_window is not None
def wait_until(callback: Callable[[], Any], timeout=0):
    """Process Qt events until ``callback`` stops raising ``AssertionError``.

    Args:
        callback: Condition to poll. It is called after each round of event
            processing and signals "not yet" by raising ``AssertionError``.
        timeout: Maximum time to wait, in milliseconds (the unit used by the
            pytest-qt calls in this file). ``0``, the default, waits
            indefinitely, which matches the previous behaviour.

    Raises:
        AssertionError: The most recent failure raised by ``callback`` once
            ``timeout`` has elapsed.
    """
    import time

    # A monotonic clock is immune to wall-clock adjustments during the wait.
    deadline = (time.monotonic() + timeout / 1000) if timeout > 0 else None

    while True:
        QCoreApplication.processEvents()
        try:
            callback()
        except AssertionError:
            # Keep polling unless a deadline was requested and has passed;
            # re-raising surfaces the real assertion message to the test.
            if deadline is not None and time.monotonic() >= deadline:
                raise
        else:
            return
class TestFileTranscriberWidget: class TestFileTranscriberWidget:
def test_should_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path): @pytest.mark.skipif(condition=platform.system() == 'Windows', reason='Waiting for signal crashes process on Windows')
def test_should_transcribe(self, qtbot: QtBot):
widget = FileTranscriberWidget( widget = FileTranscriberWidget(
file_path='testdata/whisper-french.mp3', parent=None) file_path='testdata/whisper-french.mp3', parent=None)
qtbot.addWidget(widget) qtbot.addWidget(widget)
output_file_path = tmp_path / 'whisper.txt' # Waiting for a "transcribed" signal seems to work more consistently
# than checking for the opening of a TranscriptionViewerWidget.
# See also: https://github.com/pytest-dev/pytest-qt/issues/313
with qtbot.wait_signal(widget.transcribed, timeout=30*1000):
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock, transcription_viewer = widget.findChild(TranscriptionViewerWidget)
qtbot.wait_signal(widget.transcribed, timeout=10*1000)): assert isinstance(transcription_viewer, TranscriptionViewerWidget)
save_file_name_mock.return_value = (output_file_path, '') assert len(transcription_viewer.segments) > 0
widget.run_button.click()
output_file = open(output_file_path, 'r', encoding='utf-8')
assert 'Bienvenue dans Passe' in output_file.read()
@pytest.mark.skip(reason="transcription_started callback sometimes not getting called until all progress events are emitted") @pytest.mark.skip(reason="transcription_started callback sometimes not getting called until all progress events are emitted")
def test_should_transcribe_and_stop(self, qtbot: QtBot, tmp_path: pathlib.Path): def test_should_transcribe_and_stop(self, qtbot: QtBot, tmp_path: pathlib.Path):
@ -202,13 +220,12 @@ class TestFileTranscriberWidget:
output_file_path = tmp_path / 'whisper.txt' output_file_path = tmp_path / 'whisper.txt'
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock): with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock):
save_file_name_mock.return_value = (output_file_path, '') save_file_name_mock.return_value = (str(output_file_path), '')
widget.run_button.click() widget.run_button.click()
def transcription_started(): def transcription_started():
QCoreApplication.processEvents() QCoreApplication.processEvents()
assert widget.transcriber_progress_dialog is not None assert widget.transcriber_progress_dialog is not None
logging.debug('asserted value = %s', widget.transcriber_progress_dialog.value())
assert widget.transcriber_progress_dialog.value() > 0 assert widget.transcriber_progress_dialog.value() > 0
qtbot.wait_until(transcription_started, timeout=30*1000) qtbot.wait_until(transcription_started, timeout=30*1000)
@ -270,3 +287,30 @@ class TestTemperatureValidator:
]) ])
def test_should_validate_temperature(self, text: str, state: QValidator.State): def test_should_validate_temperature(self, text: str, state: QValidator.State):
assert self.validator.validate(text, 0)[0] == state assert self.validator.validate(text, 0)[0] == state
class TestTranscriptionViewerWidget:
    """Tests for the transcription viewer's text display and export menu."""

    @pytest.fixture
    def widget(self, qtbot: QtBot) -> TranscriptionViewerWidget:
        """Create a fresh viewer for each test and register it with qtbot.

        Building the widget here, rather than as a class attribute, avoids
        constructing a QWidget at import time and keeps tests from sharing
        widget state.
        """
        widget = TranscriptionViewerWidget(
            transcription_options=FileTranscriptionOptions(
                file_path='testdata/whisper-french.mp3'),
            segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')])
        qtbot.add_widget(widget)
        return widget

    def test_should_display_segments(self, widget: TranscriptionViewerWidget):
        assert widget.windowTitle() == 'Transcription - whisper-french.mp3'

        text_display_box = widget.findChild(TextDisplayBox)
        assert isinstance(text_display_box, TextDisplayBox)
        assert text_display_box.toPlainText(
        ) == '00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans'

    def test_should_export_segments(self, widget: TranscriptionViewerWidget, tmp_path: pathlib.Path):
        export_button = widget.findChild(QPushButton)
        assert isinstance(export_button, QPushButton)

        output_file_path = tmp_path / 'whisper.txt'
        with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock:
            save_file_name_mock.return_value = (str(output_file_path), '')
            # The first menu action is the first OutputFormat member.
            export_button.menu().actions()[0].trigger()

        # The context manager closes the handle even if the assertion fails.
        with open(output_file_path, 'r', encoding='utf-8') as output_file:
            assert 'Bien venue dans' in output_file.read()

View file

@ -1,7 +1,9 @@
import logging
import os import os
import pathlib import pathlib
import tempfile import tempfile
import time import time
from typing import List
from unittest.mock import Mock from unittest.mock import Mock
from PyQt6.QtCore import QCoreApplication from PyQt6.QtCore import QCoreApplication
@ -9,11 +11,11 @@ import pytest
from pytestqt.qtbot import QtBot from pytestqt.qtbot import QtBot
from buzz.model_loader import ModelLoader from buzz.model_loader import ModelLoader
from buzz.transcriber import (OutputFormat, RecordingTranscriber, Task, from buzz.transcriber import (FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Segment, Task,
WhisperCpp, WhisperCppFileTranscriber, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber, WhisperFileTranscriber,
get_default_output_file_path, to_timestamp, get_default_output_file_path, to_timestamp,
whisper_cpp_params) whisper_cpp_params, write_output)
def get_model_path(model_name: str, use_whisper_cpp: bool) -> str: def get_model_path(model_name: str, use_whisper_cpp: bool) -> str:
@ -40,33 +42,31 @@ class TestRecordingTranscriber:
class TestWhisperCppFileTranscriber: class TestWhisperCppFileTranscriber:
@pytest.mark.parametrize( @pytest.mark.parametrize(
'task,output_text', 'word_level_timings,expected_segments',
[ [
(Task.TRANSCRIBE, 'Bienvenue dans Passe'), (False, [Segment(0, 1840, 'Bienvenue dans Passe Relle.')]),
(Task.TRANSLATE, 'Welcome to Passe-Relle'), (True, [Segment(30, 280, 'Bien'), Segment(280, 630, 'venue')])
]) ])
def test_transcribe(self, qtbot, tmp_path: pathlib.Path, task: Task, output_text: str): def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
output_file_path = tmp_path / 'whisper_cpp.txt' transcription_options = FileTranscriptionOptions(
if os.path.exists(output_file_path): language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
os.remove(output_file_path) word_level_timings=word_level_timings)
model_path = get_model_path('tiny', True) model_path = get_model_path('tiny', True)
transcriber = WhisperCppFileTranscriber( transcriber = WhisperCppFileTranscriber(
language='fr', task=task, file_path='testdata/whisper-french.mp3', transcription_options=transcription_options)
output_file_path=output_file_path.as_posix(), output_format=OutputFormat.TXT,
open_file_on_complete=False,
word_level_timings=False)
mock_progress = Mock() mock_progress = Mock()
with qtbot.waitSignal(transcriber.completed, timeout=10*60*1000): mock_completed = Mock()
transcriber.progress.connect(mock_progress) transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
with qtbot.waitSignal(transcriber.completed, timeout=10 * 60 * 1000):
transcriber.run(model_path) transcriber.run(model_path)
assert os.path.isfile(output_file_path)
output_file = open(output_file_path, 'r', encoding='utf-8')
assert output_text in output_file.read()
mock_progress.assert_called() mock_progress.assert_called()
exit_code, segments = mock_completed.call_args[0][0]
assert exit_code is 0
for expected_segment in expected_segments:
assert expected_segment in segments
class TestWhisperFileTranscriber: class TestWhisperFileTranscriber:
@ -82,36 +82,31 @@ class TestWhisperFileTranscriber:
assert srt.endswith('.srt') assert srt.endswith('.srt')
@pytest.mark.parametrize( @pytest.mark.parametrize(
'word_level_timings,output_format,output_text', 'word_level_timings,expected_segments',
[ [
(False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle'), (False, [
(False, OutputFormat.SRT, Segment(
'1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle'), 0, 6560,
(False, OutputFormat.VTT, ' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'),
'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe'), ]),
(True, OutputFormat.SRT, (True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')])
'1\n00:00:00.040 --> 00:00:00.299\n Bien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n3\n00:00:00.329 --> 00:00:00.429\n P\n\n4\n00:00:00.429 --> 00:00:00.589\nasse-'),
]) ])
def test_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str): def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
model_path = get_model_path('tiny', False) model_path = get_model_path('tiny', False)
mock_progress = Mock() mock_progress = Mock()
transcriber = WhisperFileTranscriber( mock_completed = Mock()
transcription_options = FileTranscriptionOptions(
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3', language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path.as_posix(), output_format=output_format,
open_file_on_complete=False,
word_level_timings=word_level_timings) word_level_timings=word_level_timings)
transcriber = WhisperFileTranscriber(
transcription_options=transcription_options)
transcriber.progress.connect(mock_progress) transcriber.progress.connect(mock_progress)
with qtbot.wait_signal(transcriber.completed, timeout=10*6000): transcriber.completed.connect(mock_completed)
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
transcriber.run(model_path) transcriber.run(model_path)
assert os.path.isfile(output_file_path)
output_file = open(output_file_path, 'r', encoding='utf-8')
assert output_text in output_file.read()
QCoreApplication.processEvents() QCoreApplication.processEvents()
# Reports progress at 0, 0<progress<100, and 100 # Reports progress at 0, 0<progress<100, and 100
@ -120,7 +115,14 @@ class TestWhisperFileTranscriber:
assert any( assert any(
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list]) [call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
assert any( assert any(
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in mock_progress.call_args_list]) [(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
mock_progress.call_args_list])
mock_completed.assert_called()
exit_code, segments = mock_completed.call_args[0][0]
assert exit_code is 0
for (i, expected_segment) in enumerate(expected_segments):
assert segments[i] == expected_segment
@pytest.mark.skip() @pytest.mark.skip()
def test_transcribe_stop(self): def test_transcribe_stop(self):
@ -129,11 +131,12 @@ class TestWhisperFileTranscriber:
os.remove(output_file_path) os.remove(output_file_path)
model_path = get_model_path('tiny', False) model_path = get_model_path('tiny', False)
transcriber = WhisperFileTranscriber( transcription_options = FileTranscriptionOptions(
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3', language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
output_file_path=output_file_path, output_format=OutputFormat.TXT,
open_file_on_complete=False,
word_level_timings=False) word_level_timings=False)
transcriber = WhisperFileTranscriber(
transcription_options=transcription_options)
transcriber.run(model_path) transcriber.run(model_path)
time.sleep(1) time.sleep(1)
transcriber.stop() transcriber.stop()
@ -159,3 +162,24 @@ class TestWhisperCpp:
audio='testdata/whisper-french.mp3', params=params) audio='testdata/whisper-french.mp3', params=params)
assert 'Bienvenue dans Passe' in result['text'] assert 'Bienvenue dans Passe' in result['text']
@pytest.mark.parametrize(
    'output_format,output_text',
    [
        (OutputFormat.TXT, 'Bien venue dans\n'),
        (
            OutputFormat.SRT,
            '1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
        (OutputFormat.VTT,
         'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
    ])
def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str):
    """Check the exact text ``write_output`` produces for each output format.

    Two fixed segments are written to a file under ``tmp_path`` and the
    file's full contents are compared with ``output_text``. The file name is
    ``whisper.txt`` for every case; only ``output_format`` varies.
    """
    output_file_path = tmp_path / 'whisper.txt'
    segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]

    write_output(path=str(output_file_path), segments=segments,
                 should_open=False, output_format=output_format)

    # Read through a context manager so the handle is closed before the
    # comparison; a bare open() leaked it.
    with open(output_file_path, 'r', encoding='utf-8') as output_file:
        written_text = output_file.read()
    assert output_text == written_text