mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-26 11:40:09 +02:00
Add transcription viewer (#246)
This commit is contained in:
parent
cccc8cea00
commit
7395de653f
|
@ -7,4 +7,4 @@ omit =
|
|||
directory = coverage/html
|
||||
|
||||
[report]
|
||||
fail_under = 77
|
||||
fail_under = 76
|
||||
|
|
157
buzz/gui.py
157
buzz/gui.py
|
@ -1,3 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
|
@ -9,15 +10,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
import humanize
|
||||
import sounddevice
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread,
|
||||
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, pyqtSlot,
|
||||
QThreadPool, QTimer, QUrl, pyqtSignal)
|
||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||
QKeySequence, QPixmap, QTextCursor, QValidator)
|
||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||
QDialogButtonBox, QFileDialog, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLayout, QLineEdit,
|
||||
QDialogButtonBox, QFileDialog, QGridLayout, QToolButton,
|
||||
QLabel, QLayout, QLineEdit,
|
||||
QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QProgressDialog, QPushButton, QVBoxLayout,
|
||||
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
|
||||
QWidget)
|
||||
from requests import get
|
||||
from whisper import tokenizer
|
||||
|
@ -25,10 +26,10 @@ from whisper import tokenizer
|
|||
from .__version__ import VERSION
|
||||
from .model_loader import ModelLoader
|
||||
from .transcriber import (DEFAULT_WHISPER_TEMPERATURE, LOADED_WHISPER_DLL,
|
||||
SUPPORTED_OUTPUT_FORMATS, OutputFormat,
|
||||
RecordingTranscriber, Task,
|
||||
SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
|
||||
RecordingTranscriber, Segment, Task,
|
||||
WhisperCppFileTranscriber, WhisperFileTranscriber,
|
||||
get_default_output_file_path)
|
||||
get_default_output_file_path, segments_to_text, write_output)
|
||||
|
||||
APP_NAME = 'Buzz'
|
||||
|
||||
|
@ -174,19 +175,9 @@ class QualityComboBox(QComboBox):
|
|||
class TextDisplayBox(QPlainTextEdit):
|
||||
"""TextDisplayBox is a read-only textbox"""
|
||||
|
||||
os_styles = {
|
||||
'Darwin': '''QTextEdit {
|
||||
border: 0;
|
||||
}'''
|
||||
}
|
||||
|
||||
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(parent, *args)
|
||||
self.setReadOnly(True)
|
||||
self.setPlaceholderText('Click Record to begin...')
|
||||
self.setStyleSheet(
|
||||
'''QTextEdit {
|
||||
} %s''' % get_platform_styles(self.os_styles))
|
||||
|
||||
|
||||
class RecordButton(QPushButton):
|
||||
|
@ -371,6 +362,8 @@ class FileTranscriberWidget(QWidget):
|
|||
model_loader: Optional[ModelLoader] = None
|
||||
transcriber_thread: Optional[QThread] = None
|
||||
transcribed = pyqtSignal()
|
||||
transcription_options: FileTranscriptionOptions
|
||||
is_transcribing = False
|
||||
|
||||
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
|
||||
super().__init__(parent)
|
||||
|
@ -378,6 +371,8 @@ class FileTranscriberWidget(QWidget):
|
|||
layout = QGridLayout(self)
|
||||
|
||||
self.settings = Settings(self)
|
||||
self.transcription_options = FileTranscriptionOptions(
|
||||
file_path=file_path)
|
||||
|
||||
self.file_path = file_path
|
||||
|
||||
|
@ -462,15 +457,6 @@ class FileTranscriberWidget(QWidget):
|
|||
self.initial_prompt = initial_prompt
|
||||
|
||||
def on_click_run(self):
|
||||
default_path = get_default_output_file_path(
|
||||
task=self.selected_task, input_file_path=self.file_path,
|
||||
output_format=self.selected_output_format)
|
||||
(output_file, _) = QFileDialog.getSaveFileName(
|
||||
self, 'Save File', default_path, f'Text files (*.{self.selected_output_format.value})')
|
||||
|
||||
if output_file == '':
|
||||
return
|
||||
|
||||
use_whisper_cpp = self.settings.get_enable_ggml_inference(
|
||||
) and self.selected_language is not None
|
||||
|
||||
|
@ -482,20 +468,18 @@ class FileTranscriberWidget(QWidget):
|
|||
self.model_loader = ModelLoader(
|
||||
name=model_name, use_whisper_cpp=use_whisper_cpp)
|
||||
|
||||
self.transcription_options = FileTranscriptionOptions(
|
||||
file_path=self.file_path, language=self.selected_language,
|
||||
task=self.selected_task, word_level_timings=self.enabled_word_level_timings,
|
||||
temperature=self.temperature, initial_prompt=self.initial_prompt
|
||||
)
|
||||
|
||||
if use_whisper_cpp:
|
||||
self.file_transcriber = WhisperCppFileTranscriber(
|
||||
file_path=self.file_path, language=self.selected_language, task=self.selected_task,
|
||||
output_file_path=output_file, output_format=self.selected_output_format,
|
||||
word_level_timings=self.enabled_word_level_timings,
|
||||
)
|
||||
self.transcription_options)
|
||||
else:
|
||||
self.file_transcriber = WhisperFileTranscriber(
|
||||
file_path=self.file_path, language=self.selected_language, task=self.selected_task,
|
||||
output_file_path=output_file, output_format=self.selected_output_format,
|
||||
word_level_timings=self.enabled_word_level_timings,
|
||||
temperature=self.temperature,
|
||||
initial_prompt=self.initial_prompt
|
||||
)
|
||||
self.transcription_options)
|
||||
|
||||
self.model_loader.moveToThread(self.transcriber_thread)
|
||||
self.file_transcriber.moveToThread(self.transcriber_thread)
|
||||
|
@ -515,6 +499,7 @@ class FileTranscriberWidget(QWidget):
|
|||
self.model_loader.finished.connect(self.model_loader.deleteLater)
|
||||
|
||||
# Run the file transcriber after the model loads
|
||||
self.model_loader.finished.connect(self.on_model_loaded)
|
||||
self.model_loader.finished.connect(self.file_transcriber.run)
|
||||
|
||||
self.file_transcriber.progress.connect(
|
||||
|
@ -527,6 +512,9 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
self.transcriber_thread.start()
|
||||
|
||||
def on_model_loaded(self):
|
||||
self.is_transcribing = True
|
||||
|
||||
def on_download_model_progress(self, progress: Tuple[int, int]):
|
||||
(current_size, total_size) = progress
|
||||
|
||||
|
@ -545,22 +533,26 @@ class FileTranscriberWidget(QWidget):
|
|||
self.reset_transcriber_controls()
|
||||
|
||||
def on_transcriber_progress(self, progress: Tuple[int, int]):
|
||||
logging.debug('received progress = %s', progress)
|
||||
(current_size, total_size) = progress
|
||||
|
||||
# Create a dialog
|
||||
if self.transcriber_progress_dialog is None:
|
||||
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
||||
file_path=self.file_path, total_size=total_size, parent=self)
|
||||
self.transcriber_progress_dialog.canceled.connect(
|
||||
self.on_cancel_transcriber_progress_dialog)
|
||||
if self.is_transcribing:
|
||||
# Create a dialog
|
||||
if self.transcriber_progress_dialog is None:
|
||||
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
||||
file_path=self.file_path, total_size=total_size, parent=self)
|
||||
self.transcriber_progress_dialog.canceled.connect(
|
||||
self.on_cancel_transcriber_progress_dialog)
|
||||
else:
|
||||
# Update the progress of the dialog unless it has
|
||||
# been canceled before this progress update arrived
|
||||
self.transcriber_progress_dialog.update_progress(current_size)
|
||||
|
||||
# Update the progress of the dialog unless it has
|
||||
# been canceled before this progress update arrived
|
||||
if self.transcriber_progress_dialog is not None:
|
||||
self.transcriber_progress_dialog.update_progress(current_size)
|
||||
@pyqtSlot(tuple)
|
||||
def on_transcriber_complete(self, result: Tuple[int, List[Segment]]):
|
||||
exit_code, segments = result
|
||||
|
||||
self.is_transcribing = False
|
||||
|
||||
def on_transcriber_complete(self, exit_code: int):
|
||||
if self.transcriber_progress_dialog is not None:
|
||||
self.transcriber_progress_dialog.reset()
|
||||
if exit_code != 0:
|
||||
|
@ -569,6 +561,10 @@ class FileTranscriberWidget(QWidget):
|
|||
self.reset_transcriber_controls()
|
||||
self.transcribed.emit()
|
||||
|
||||
TranscriptionViewerWidget(
|
||||
transcription_options=self.transcription_options,
|
||||
segments=segments, parent=self, flags=Qt.WindowType.Window).show()
|
||||
|
||||
def on_cancel_transcriber_progress_dialog(self):
|
||||
if self.file_transcriber is not None:
|
||||
self.file_transcriber.stop()
|
||||
|
@ -590,6 +586,70 @@ class FileTranscriberWidget(QWidget):
|
|||
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value
|
||||
|
||||
|
||||
class TranscriptionViewerWidget(QWidget):
|
||||
segments: List[Segment]
|
||||
transcription_options: FileTranscriptionOptions
|
||||
|
||||
def __init__(
|
||||
self, transcription_options: FileTranscriptionOptions, segments: List[Segment],
|
||||
parent: Optional['QWidget'] = None, flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||
) -> None:
|
||||
super().__init__(parent, flags)
|
||||
self.segments = segments
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
self.setMinimumWidth(500)
|
||||
self.setMinimumHeight(500)
|
||||
|
||||
self.setWindowTitle(
|
||||
f'Transcription - {get_short_file_path(transcription_options.file_path)}')
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
text = segments_to_text(segments)
|
||||
|
||||
self.text_box = TextDisplayBox(self)
|
||||
self.text_box.setPlainText(text)
|
||||
|
||||
layout.addWidget(self.text_box)
|
||||
|
||||
buttons_layout = QHBoxLayout()
|
||||
buttons_layout.addStretch()
|
||||
|
||||
menu = QMenu()
|
||||
actions = [QAction(text=output_format.value.upper(), parent=self)
|
||||
for output_format in OutputFormat]
|
||||
menu.addActions(actions)
|
||||
|
||||
menu.triggered.connect(self.on_menu_triggered)
|
||||
|
||||
export_button = QPushButton(self)
|
||||
export_button.setText('Export')
|
||||
export_button.setMenu(menu)
|
||||
|
||||
buttons_layout.addWidget(export_button)
|
||||
layout.addLayout(buttons_layout)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def on_menu_triggered(self, action: QAction):
|
||||
output_format = OutputFormat[action.text()]
|
||||
|
||||
default_path = get_default_output_file_path(
|
||||
task=self.transcription_options.task,
|
||||
input_file_path=self.transcription_options.file_path,
|
||||
output_format=output_format)
|
||||
|
||||
(output_file_path, _) = QFileDialog.getSaveFileName(
|
||||
self, 'Save File', default_path, f'Text files (*.{output_format.value})')
|
||||
|
||||
if output_file_path == '':
|
||||
return
|
||||
|
||||
write_output(path=output_file_path, segments=self.segments,
|
||||
should_open=True, output_format=output_format)
|
||||
|
||||
|
||||
class Settings(QSettings):
|
||||
_ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
|
||||
|
||||
|
@ -686,6 +746,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.open_advanced_settings)
|
||||
|
||||
self.text_box = TextDisplayBox(self)
|
||||
self.text_box.setPlaceholderText('Click Record to begin...')
|
||||
|
||||
widgets = [
|
||||
((0, 5, FormLabel('Task:', self)), (5, 7, self.tasks_combo_box)),
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import ctypes
|
||||
import datetime
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -195,28 +197,33 @@ class OutputFormat(enum.Enum):
|
|||
VTT = 'vtt'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileTranscriptionOptions:
|
||||
file_path: str
|
||||
language: Optional[str] = None
|
||||
task: Task = Task.TRANSCRIBE
|
||||
word_level_timings: bool = False
|
||||
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
|
||||
initial_prompt: str = ''
|
||||
|
||||
|
||||
class WhisperCppFileTranscriber(QObject):
|
||||
progress = pyqtSignal(tuple) # (current, total)
|
||||
completed = pyqtSignal(int)
|
||||
completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
|
||||
error = pyqtSignal(str)
|
||||
duration_audio_ms = sys.maxsize # max int
|
||||
segments: List[Segment]
|
||||
running = False
|
||||
|
||||
def __init__(
|
||||
self, language: Optional[str], task: Task, file_path: str,
|
||||
output_file_path: str, output_format: OutputFormat,
|
||||
word_level_timings: bool, open_file_on_complete=True,
|
||||
self, transcription_options: FileTranscriptionOptions,
|
||||
parent: Optional['QObject'] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.file_path = file_path
|
||||
self.output_file_path = output_file_path
|
||||
self.language = language
|
||||
self.task = task
|
||||
self.open_file_on_complete = open_file_on_complete
|
||||
self.output_format = output_format
|
||||
self.word_level_timings = word_level_timings
|
||||
self.file_path = transcription_options.file_path
|
||||
self.language = transcription_options.language
|
||||
self.task = transcription_options.task
|
||||
self.word_level_timings = transcription_options.word_level_timings
|
||||
self.segments = []
|
||||
|
||||
self.process = QProcess(self)
|
||||
|
@ -228,8 +235,8 @@ class WhisperCppFileTranscriber(QObject):
|
|||
self.running = True
|
||||
|
||||
logging.debug(
|
||||
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s',
|
||||
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path)
|
||||
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s',
|
||||
self.file_path, self.language, self.task, model_path)
|
||||
|
||||
wav_file = tempfile.mktemp()+'.wav'
|
||||
(
|
||||
|
@ -258,10 +265,8 @@ class WhisperCppFileTranscriber(QObject):
|
|||
if status == QProcess.ExitStatus.NormalExit:
|
||||
self.progress.emit(
|
||||
(self.duration_audio_ms, self.duration_audio_ms))
|
||||
write_output(
|
||||
self.output_file_path, self.segments, self.open_file_on_complete, self.output_format)
|
||||
|
||||
self.completed.emit(status == self.process.exitCode())
|
||||
self.completed.emit((self.process.exitCode(), self.segments))
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
|
@ -282,7 +287,7 @@ class WhisperCppFileTranscriber(QObject):
|
|||
segment = Segment(start, end, text.strip())
|
||||
self.segments.append(segment)
|
||||
self.progress.emit((end, self.duration_audio_ms))
|
||||
except UnicodeDecodeError:
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
def parse_timings(self, timings: str) -> Tuple[int, int]:
|
||||
|
@ -315,44 +320,39 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
current_process: multiprocessing.Process
|
||||
progress = pyqtSignal(tuple) # (current, total)
|
||||
completed = pyqtSignal(int)
|
||||
completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
|
||||
error = pyqtSignal(str)
|
||||
running = False
|
||||
read_line_thread: Optional[Thread] = None
|
||||
|
||||
def __init__(self,
|
||||
language: Optional[str], task: Task, file_path: str,
|
||||
output_file_path: str, output_format: OutputFormat,
|
||||
word_level_timings: bool, temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE,
|
||||
initial_prompt: str = '', open_file_on_complete=True, parent: Optional['QObject'] = None) -> None:
|
||||
def __init__(
|
||||
self, transcription_options: FileTranscriptionOptions,
|
||||
parent: Optional['QObject'] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.file_path = file_path
|
||||
self.output_file_path = output_file_path
|
||||
self.language = language
|
||||
self.task = task
|
||||
self.open_file_on_complete = open_file_on_complete
|
||||
self.output_format = output_format
|
||||
self.word_level_timings = word_level_timings
|
||||
self.temperature = temperature
|
||||
self.initial_prompt = initial_prompt
|
||||
self.file_path = transcription_options.file_path
|
||||
self.language = transcription_options.language
|
||||
self.task = transcription_options.task
|
||||
self.word_level_timings = transcription_options.word_level_timings
|
||||
self.temperature = transcription_options.temperature
|
||||
self.initial_prompt = transcription_options.initial_prompt
|
||||
self.segments = []
|
||||
|
||||
@pyqtSlot(str)
|
||||
def run(self, model_path: str):
|
||||
self.running = True
|
||||
time_started = datetime.datetime.now()
|
||||
logging.debug(
|
||||
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model path = %s, temperature = %s, initial prompt length = %s',
|
||||
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path, self.temperature, len(self.initial_prompt))
|
||||
'Starting whisper file transcription, file path = %s, language = %s, task = %s, model path = %s, temperature = %s, initial prompt length = %s',
|
||||
self.file_path, self.language, self.task, model_path, self.temperature, len(self.initial_prompt))
|
||||
|
||||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=transcribe_whisper,
|
||||
args=(
|
||||
send_pipe, model_path, self.file_path,
|
||||
self.language, self.task, self.output_file_path,
|
||||
self.open_file_on_complete, self.output_format,
|
||||
self.word_level_timings,
|
||||
self.language, self.task, self.word_level_timings,
|
||||
self.temperature, self.initial_prompt
|
||||
))
|
||||
|
||||
|
@ -373,7 +373,9 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
self.read_line_thread.join()
|
||||
|
||||
self.completed.emit(self.current_process.exitcode)
|
||||
if self.current_process.exitcode != 0:
|
||||
self.completed.emit((self.current_process.exitcode, []))
|
||||
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
|
@ -381,6 +383,17 @@ class WhisperFileTranscriber(QObject):
|
|||
self.current_process.terminate()
|
||||
|
||||
def on_whisper_stdout(self, line: str):
|
||||
if line.startswith('segments = '):
|
||||
segments_dict = json.loads(line[11:])
|
||||
segments = [Segment(
|
||||
start=segment.get('start'),
|
||||
end=segment.get('end'),
|
||||
text=segment.get('text'),
|
||||
) for segment in segments_dict]
|
||||
self.current_process.join()
|
||||
self.completed.emit((self.current_process.exitcode, segments))
|
||||
return
|
||||
|
||||
try:
|
||||
progress = int(line.split('|')[0].strip().strip('%'))
|
||||
self.progress.emit((progress, 100))
|
||||
|
@ -398,8 +411,7 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
def transcribe_whisper(
|
||||
stderr_conn: Connection, model_path: str, file_path: str,
|
||||
language: Optional[str], task: Task, output_file_path: str,
|
||||
open_file_on_complete: bool, output_format: OutputFormat,
|
||||
language: Optional[str], task: Task,
|
||||
word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
|
||||
with pipe_stderr(stderr_conn):
|
||||
model = whisper.load_model(model_path)
|
||||
|
@ -418,15 +430,15 @@ def transcribe_whisper(
|
|||
whisper_segments = stable_whisper.group_word_timestamps(
|
||||
result) if word_level_timings else result.get('segments')
|
||||
|
||||
segments = map(
|
||||
lambda segment: Segment(
|
||||
start=segment.get('start')*1000, # s to ms
|
||||
end=segment.get('end')*1000, # s to ms
|
||||
text=segment.get('text')),
|
||||
whisper_segments)
|
||||
|
||||
write_output(output_file_path, list(
|
||||
segments), open_file_on_complete, output_format)
|
||||
segments = [
|
||||
Segment(
|
||||
start=int(segment.get('start')*1000),
|
||||
end=int(segment.get('end')*1000),
|
||||
text=segment.get('text'),
|
||||
) for segment in whisper_segments]
|
||||
segments_json = json.dumps(
|
||||
segments, ensure_ascii=True, default=vars)
|
||||
sys.stderr.write(f'segments = {segments_json}\n')
|
||||
|
||||
|
||||
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
|
||||
|
@ -435,8 +447,11 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
|
|||
|
||||
with open(path, 'w', encoding='utf-8') as file:
|
||||
if output_format == OutputFormat.TXT:
|
||||
for segment in segments:
|
||||
file.write(segment.text + ' ')
|
||||
for (i, segment) in enumerate(segments):
|
||||
file.write(segment.text)
|
||||
if i < len(segments)-1:
|
||||
file.write(' ')
|
||||
file.write('\n')
|
||||
|
||||
elif output_format == OutputFormat.VTT:
|
||||
file.write('WEBVTT\n\n')
|
||||
|
@ -460,6 +475,16 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
|
|||
subprocess.call([opener, path])
|
||||
|
||||
|
||||
def segments_to_text(segments: List[Segment]) -> str:
|
||||
result = ''
|
||||
for (i, segment) in enumerate(segments):
|
||||
result += f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n'
|
||||
result += f'{segment.text}'
|
||||
if i < len(segments)-1:
|
||||
result += '\n\n'
|
||||
return result
|
||||
|
||||
|
||||
def to_timestamp(ms: float) -> str:
|
||||
hr = int(ms / (1000*60*60))
|
||||
ms = ms - hr * (1000*60*60)
|
||||
|
|
|
@ -1,21 +1,24 @@
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Any, Callable
|
||||
import platform
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import sounddevice
|
||||
from PyQt6.QtCore import Qt, QCoreApplication
|
||||
from PyQt6.QtGui import (QValidator)
|
||||
from PyQt6.QtWidgets import (QPushButton)
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
|
||||
AudioDevicesComboBox, DownloadModelProgressDialog,
|
||||
FileTranscriberWidget, LanguagesComboBox, MainWindow,
|
||||
OutputFormatsComboBox, Quality, QualityComboBox,
|
||||
Settings, TemperatureValidator,
|
||||
TranscriberProgressDialog)
|
||||
from buzz.transcriber import OutputFormat
|
||||
Settings, TemperatureValidator, TextDisplayBox,
|
||||
TranscriberProgressDialog, TranscriptionViewerWidget)
|
||||
from buzz.transcriber import FileTranscriptionOptions, OutputFormat, Segment
|
||||
|
||||
|
||||
class TestApplication:
|
||||
|
@ -142,25 +145,29 @@ class TestTranscriberProgressDialog:
|
|||
|
||||
|
||||
class TestDownloadModelProgressDialog:
|
||||
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||
def test_should_show_dialog(self, qtbot: QtBot):
|
||||
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||
qtbot.add_widget(dialog)
|
||||
assert dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
|
||||
|
||||
def test_should_show_dialog(self):
|
||||
assert self.dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
|
||||
def test_should_update_label_on_progress(self, qtbot: QtBot):
|
||||
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||
qtbot.add_widget(dialog)
|
||||
dialog.setValue(0)
|
||||
|
||||
def test_should_update_label_on_progress(self):
|
||||
self.dialog.setValue(0)
|
||||
|
||||
self.dialog.setValue(12345)
|
||||
assert self.dialog.labelText().startswith(
|
||||
dialog.setValue(12345)
|
||||
assert dialog.labelText().startswith(
|
||||
'Downloading resources (1.00%')
|
||||
|
||||
self.dialog.setValue(123456)
|
||||
assert self.dialog.labelText().startswith(
|
||||
dialog.setValue(123456)
|
||||
assert dialog.labelText().startswith(
|
||||
'Downloading resources (10.00%')
|
||||
|
||||
# Other windows should not be processing while models are being downloaded
|
||||
def test_should_be_an_application_modal(self):
|
||||
assert self.dialog.windowModality() == Qt.WindowModality.ApplicationModal
|
||||
def test_should_be_an_application_modal(self, qtbot: QtBot):
|
||||
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||
qtbot.add_widget(dialog)
|
||||
assert dialog.windowModality() == Qt.WindowModality.ApplicationModal
|
||||
|
||||
|
||||
class TestFormatsComboBox:
|
||||
|
@ -177,21 +184,32 @@ class TestMainWindow:
|
|||
assert main_window is not None
|
||||
|
||||
|
||||
def wait_until(callback: Callable[[], Any], timeout=0):
|
||||
while True:
|
||||
try:
|
||||
QCoreApplication.processEvents()
|
||||
callback()
|
||||
return
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
|
||||
class TestFileTranscriberWidget:
|
||||
def test_should_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path):
|
||||
@pytest.mark.skipif(condition=platform.system() == 'Windows', reason='Waiting for signal crashes process on Windows')
|
||||
def test_should_transcribe(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_path='testdata/whisper-french.mp3', parent=None)
|
||||
qtbot.addWidget(widget)
|
||||
|
||||
output_file_path = tmp_path / 'whisper.txt'
|
||||
# Waiting for a "transcribed" signal seems to work more consistently
|
||||
# than checking for the opening of a TranscriptionViewerWidget.
|
||||
# See also: https://github.com/pytest-dev/pytest-qt/issues/313
|
||||
with qtbot.wait_signal(widget.transcribed, timeout=30*1000):
|
||||
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
|
||||
|
||||
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock,
|
||||
qtbot.wait_signal(widget.transcribed, timeout=10*1000)):
|
||||
save_file_name_mock.return_value = (output_file_path, '')
|
||||
widget.run_button.click()
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert 'Bienvenue dans Passe' in output_file.read()
|
||||
transcription_viewer = widget.findChild(TranscriptionViewerWidget)
|
||||
assert isinstance(transcription_viewer, TranscriptionViewerWidget)
|
||||
assert len(transcription_viewer.segments) > 0
|
||||
|
||||
@pytest.mark.skip(reason="transcription_started callback sometimes not getting called until all progress events are emitted")
|
||||
def test_should_transcribe_and_stop(self, qtbot: QtBot, tmp_path: pathlib.Path):
|
||||
|
@ -202,13 +220,12 @@ class TestFileTranscriberWidget:
|
|||
output_file_path = tmp_path / 'whisper.txt'
|
||||
|
||||
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock):
|
||||
save_file_name_mock.return_value = (output_file_path, '')
|
||||
save_file_name_mock.return_value = (str(output_file_path), '')
|
||||
widget.run_button.click()
|
||||
|
||||
def transcription_started():
|
||||
QCoreApplication.processEvents()
|
||||
assert widget.transcriber_progress_dialog is not None
|
||||
logging.debug('asserted value = %s', widget.transcriber_progress_dialog.value())
|
||||
assert widget.transcriber_progress_dialog.value() > 0
|
||||
qtbot.wait_until(transcription_started, timeout=30*1000)
|
||||
|
||||
|
@ -270,3 +287,30 @@ class TestTemperatureValidator:
|
|||
])
|
||||
def test_should_validate_temperature(self, text: str, state: QValidator.State):
|
||||
assert self.validator.validate(text, 0)[0] == state
|
||||
|
||||
|
||||
class TestTranscriptionViewerWidget:
|
||||
widget = TranscriptionViewerWidget(
|
||||
transcription_options=FileTranscriptionOptions(
|
||||
file_path='testdata/whisper-french.mp3'),
|
||||
segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')])
|
||||
|
||||
def test_should_display_segments(self):
|
||||
assert self.widget.windowTitle() == 'Transcription - whisper-french.mp3'
|
||||
|
||||
text_display_box = self.widget.findChild(TextDisplayBox)
|
||||
assert isinstance(text_display_box, TextDisplayBox)
|
||||
assert text_display_box.toPlainText(
|
||||
) == '00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans'
|
||||
|
||||
def test_should_export_segments(self, tmp_path: pathlib.Path):
|
||||
export_button = self.widget.findChild(QPushButton)
|
||||
assert isinstance(export_button, QPushButton)
|
||||
|
||||
output_file_path = tmp_path / 'whisper.txt'
|
||||
with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock:
|
||||
save_file_name_mock.return_value = (str(output_file_path), '')
|
||||
export_button.menu().actions()[0].trigger()
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert 'Bien venue dans' in output_file.read()
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
from PyQt6.QtCore import QCoreApplication
|
||||
|
@ -9,11 +11,11 @@ import pytest
|
|||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.model_loader import ModelLoader
|
||||
from buzz.transcriber import (OutputFormat, RecordingTranscriber, Task,
|
||||
from buzz.transcriber import (FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Segment, Task,
|
||||
WhisperCpp, WhisperCppFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
get_default_output_file_path, to_timestamp,
|
||||
whisper_cpp_params)
|
||||
whisper_cpp_params, write_output)
|
||||
|
||||
|
||||
def get_model_path(model_name: str, use_whisper_cpp: bool) -> str:
|
||||
|
@ -40,33 +42,31 @@ class TestRecordingTranscriber:
|
|||
|
||||
class TestWhisperCppFileTranscriber:
|
||||
@pytest.mark.parametrize(
|
||||
'task,output_text',
|
||||
'word_level_timings,expected_segments',
|
||||
[
|
||||
(Task.TRANSCRIBE, 'Bienvenue dans Passe'),
|
||||
(Task.TRANSLATE, 'Welcome to Passe-Relle'),
|
||||
(False, [Segment(0, 1840, 'Bienvenue dans Passe Relle.')]),
|
||||
(True, [Segment(30, 280, 'Bien'), Segment(280, 630, 'venue')])
|
||||
])
|
||||
def test_transcribe(self, qtbot, tmp_path: pathlib.Path, task: Task, output_text: str):
|
||||
output_file_path = tmp_path / 'whisper_cpp.txt'
|
||||
if os.path.exists(output_file_path):
|
||||
os.remove(output_file_path)
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||
transcription_options = FileTranscriptionOptions(
|
||||
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
word_level_timings=word_level_timings)
|
||||
|
||||
model_path = get_model_path('tiny', True)
|
||||
transcriber = WhisperCppFileTranscriber(
|
||||
language='fr', task=task, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path.as_posix(), output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False,
|
||||
word_level_timings=False)
|
||||
transcription_options=transcription_options)
|
||||
mock_progress = Mock()
|
||||
with qtbot.waitSignal(transcriber.completed, timeout=10*60*1000):
|
||||
transcriber.progress.connect(mock_progress)
|
||||
mock_completed = Mock()
|
||||
transcriber.progress.connect(mock_progress)
|
||||
transcriber.completed.connect(mock_completed)
|
||||
with qtbot.waitSignal(transcriber.completed, timeout=10 * 60 * 1000):
|
||||
transcriber.run(model_path)
|
||||
|
||||
assert os.path.isfile(output_file_path)
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert output_text in output_file.read()
|
||||
|
||||
mock_progress.assert_called()
|
||||
exit_code, segments = mock_completed.call_args[0][0]
|
||||
assert exit_code is 0
|
||||
for expected_segment in expected_segments:
|
||||
assert expected_segment in segments
|
||||
|
||||
|
||||
class TestWhisperFileTranscriber:
|
||||
|
@ -82,36 +82,31 @@ class TestWhisperFileTranscriber:
|
|||
assert srt.endswith('.srt')
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'word_level_timings,output_format,output_text',
|
||||
'word_level_timings,expected_segments',
|
||||
[
|
||||
(False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle'),
|
||||
(False, OutputFormat.SRT,
|
||||
'1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle'),
|
||||
(False, OutputFormat.VTT,
|
||||
'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe'),
|
||||
(True, OutputFormat.SRT,
|
||||
'1\n00:00:00.040 --> 00:00:00.299\n Bien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n3\n00:00:00.329 --> 00:00:00.429\n P\n\n4\n00:00:00.429 --> 00:00:00.589\nasse-'),
|
||||
(False, [
|
||||
Segment(
|
||||
0, 6560,
|
||||
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'),
|
||||
]),
|
||||
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')])
|
||||
])
|
||||
def test_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str):
|
||||
output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
|
||||
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||
model_path = get_model_path('tiny', False)
|
||||
|
||||
mock_progress = Mock()
|
||||
transcriber = WhisperFileTranscriber(
|
||||
mock_completed = Mock()
|
||||
transcription_options = FileTranscriptionOptions(
|
||||
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path.as_posix(), output_format=output_format,
|
||||
open_file_on_complete=False,
|
||||
word_level_timings=word_level_timings)
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
transcription_options=transcription_options)
|
||||
transcriber.progress.connect(mock_progress)
|
||||
with qtbot.wait_signal(transcriber.completed, timeout=10*6000):
|
||||
transcriber.completed.connect(mock_completed)
|
||||
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
|
||||
transcriber.run(model_path)
|
||||
|
||||
assert os.path.isfile(output_file_path)
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert output_text in output_file.read()
|
||||
|
||||
QCoreApplication.processEvents()
|
||||
|
||||
# Reports progress at 0, 0<progress<100, and 100
|
||||
|
@ -120,7 +115,14 @@ class TestWhisperFileTranscriber:
|
|||
assert any(
|
||||
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
|
||||
assert any(
|
||||
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in mock_progress.call_args_list])
|
||||
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
|
||||
mock_progress.call_args_list])
|
||||
|
||||
mock_completed.assert_called()
|
||||
exit_code, segments = mock_completed.call_args[0][0]
|
||||
assert exit_code is 0
|
||||
for (i, expected_segment) in enumerate(expected_segments):
|
||||
assert segments[i] == expected_segment
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_transcribe_stop(self):
|
||||
|
@ -129,11 +131,12 @@ class TestWhisperFileTranscriber:
|
|||
os.remove(output_file_path)
|
||||
|
||||
model_path = get_model_path('tiny', False)
|
||||
transcriber = WhisperFileTranscriber(
|
||||
transcription_options = FileTranscriptionOptions(
|
||||
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
||||
open_file_on_complete=False,
|
||||
word_level_timings=False)
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
transcription_options=transcription_options)
|
||||
transcriber.run(model_path)
|
||||
time.sleep(1)
|
||||
transcriber.stop()
|
||||
|
@ -159,3 +162,24 @@ class TestWhisperCpp:
|
|||
audio='testdata/whisper-french.mp3', params=params)
|
||||
|
||||
assert 'Bienvenue dans Passe' in result['text']
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'output_format,output_text',
|
||||
[
|
||||
(OutputFormat.TXT, 'Bien venue dans\n'),
|
||||
(
|
||||
OutputFormat.SRT,
|
||||
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
(OutputFormat.VTT,
|
||||
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
])
|
||||
def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str):
|
||||
output_file_path = tmp_path / 'whisper.txt'
|
||||
segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]
|
||||
|
||||
write_output(path=str(output_file_path), segments=segments,
|
||||
should_open=False, output_format=output_format)
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert output_text == output_file.read()
|
||||
|
|
Loading…
Reference in a new issue