mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-29 05:00:21 +02:00
Add transcription viewer (#246)
This commit is contained in:
parent
cccc8cea00
commit
7395de653f
|
@ -7,4 +7,4 @@ omit =
|
||||||
directory = coverage/html
|
directory = coverage/html
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
fail_under = 77
|
fail_under = 76
|
||||||
|
|
157
buzz/gui.py
157
buzz/gui.py
|
@ -1,3 +1,4 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -9,15 +10,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
import humanize
|
import humanize
|
||||||
import sounddevice
|
import sounddevice
|
||||||
from PyQt6 import QtGui
|
from PyQt6 import QtGui
|
||||||
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread,
|
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, pyqtSlot,
|
||||||
QThreadPool, QTimer, QUrl, pyqtSignal)
|
QThreadPool, QTimer, QUrl, pyqtSignal)
|
||||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||||
QKeySequence, QPixmap, QTextCursor, QValidator)
|
QKeySequence, QPixmap, QTextCursor, QValidator)
|
||||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||||
QDialogButtonBox, QFileDialog, QGridLayout,
|
QDialogButtonBox, QFileDialog, QGridLayout, QToolButton,
|
||||||
QHBoxLayout, QLabel, QLayout, QLineEdit,
|
QLabel, QLayout, QLineEdit,
|
||||||
QMainWindow, QMessageBox, QPlainTextEdit,
|
QMainWindow, QMessageBox, QPlainTextEdit,
|
||||||
QProgressDialog, QPushButton, QVBoxLayout,
|
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
|
||||||
QWidget)
|
QWidget)
|
||||||
from requests import get
|
from requests import get
|
||||||
from whisper import tokenizer
|
from whisper import tokenizer
|
||||||
|
@ -25,10 +26,10 @@ from whisper import tokenizer
|
||||||
from .__version__ import VERSION
|
from .__version__ import VERSION
|
||||||
from .model_loader import ModelLoader
|
from .model_loader import ModelLoader
|
||||||
from .transcriber import (DEFAULT_WHISPER_TEMPERATURE, LOADED_WHISPER_DLL,
|
from .transcriber import (DEFAULT_WHISPER_TEMPERATURE, LOADED_WHISPER_DLL,
|
||||||
SUPPORTED_OUTPUT_FORMATS, OutputFormat,
|
SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
|
||||||
RecordingTranscriber, Task,
|
RecordingTranscriber, Segment, Task,
|
||||||
WhisperCppFileTranscriber, WhisperFileTranscriber,
|
WhisperCppFileTranscriber, WhisperFileTranscriber,
|
||||||
get_default_output_file_path)
|
get_default_output_file_path, segments_to_text, write_output)
|
||||||
|
|
||||||
APP_NAME = 'Buzz'
|
APP_NAME = 'Buzz'
|
||||||
|
|
||||||
|
@ -174,19 +175,9 @@ class QualityComboBox(QComboBox):
|
||||||
class TextDisplayBox(QPlainTextEdit):
|
class TextDisplayBox(QPlainTextEdit):
|
||||||
"""TextDisplayBox is a read-only textbox"""
|
"""TextDisplayBox is a read-only textbox"""
|
||||||
|
|
||||||
os_styles = {
|
|
||||||
'Darwin': '''QTextEdit {
|
|
||||||
border: 0;
|
|
||||||
}'''
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
||||||
super().__init__(parent, *args)
|
super().__init__(parent, *args)
|
||||||
self.setReadOnly(True)
|
self.setReadOnly(True)
|
||||||
self.setPlaceholderText('Click Record to begin...')
|
|
||||||
self.setStyleSheet(
|
|
||||||
'''QTextEdit {
|
|
||||||
} %s''' % get_platform_styles(self.os_styles))
|
|
||||||
|
|
||||||
|
|
||||||
class RecordButton(QPushButton):
|
class RecordButton(QPushButton):
|
||||||
|
@ -371,6 +362,8 @@ class FileTranscriberWidget(QWidget):
|
||||||
model_loader: Optional[ModelLoader] = None
|
model_loader: Optional[ModelLoader] = None
|
||||||
transcriber_thread: Optional[QThread] = None
|
transcriber_thread: Optional[QThread] = None
|
||||||
transcribed = pyqtSignal()
|
transcribed = pyqtSignal()
|
||||||
|
transcription_options: FileTranscriptionOptions
|
||||||
|
is_transcribing = False
|
||||||
|
|
||||||
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
|
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
|
@ -378,6 +371,8 @@ class FileTranscriberWidget(QWidget):
|
||||||
layout = QGridLayout(self)
|
layout = QGridLayout(self)
|
||||||
|
|
||||||
self.settings = Settings(self)
|
self.settings = Settings(self)
|
||||||
|
self.transcription_options = FileTranscriptionOptions(
|
||||||
|
file_path=file_path)
|
||||||
|
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
|
|
||||||
|
@ -462,15 +457,6 @@ class FileTranscriberWidget(QWidget):
|
||||||
self.initial_prompt = initial_prompt
|
self.initial_prompt = initial_prompt
|
||||||
|
|
||||||
def on_click_run(self):
|
def on_click_run(self):
|
||||||
default_path = get_default_output_file_path(
|
|
||||||
task=self.selected_task, input_file_path=self.file_path,
|
|
||||||
output_format=self.selected_output_format)
|
|
||||||
(output_file, _) = QFileDialog.getSaveFileName(
|
|
||||||
self, 'Save File', default_path, f'Text files (*.{self.selected_output_format.value})')
|
|
||||||
|
|
||||||
if output_file == '':
|
|
||||||
return
|
|
||||||
|
|
||||||
use_whisper_cpp = self.settings.get_enable_ggml_inference(
|
use_whisper_cpp = self.settings.get_enable_ggml_inference(
|
||||||
) and self.selected_language is not None
|
) and self.selected_language is not None
|
||||||
|
|
||||||
|
@ -482,20 +468,18 @@ class FileTranscriberWidget(QWidget):
|
||||||
self.model_loader = ModelLoader(
|
self.model_loader = ModelLoader(
|
||||||
name=model_name, use_whisper_cpp=use_whisper_cpp)
|
name=model_name, use_whisper_cpp=use_whisper_cpp)
|
||||||
|
|
||||||
|
self.transcription_options = FileTranscriptionOptions(
|
||||||
|
file_path=self.file_path, language=self.selected_language,
|
||||||
|
task=self.selected_task, word_level_timings=self.enabled_word_level_timings,
|
||||||
|
temperature=self.temperature, initial_prompt=self.initial_prompt
|
||||||
|
)
|
||||||
|
|
||||||
if use_whisper_cpp:
|
if use_whisper_cpp:
|
||||||
self.file_transcriber = WhisperCppFileTranscriber(
|
self.file_transcriber = WhisperCppFileTranscriber(
|
||||||
file_path=self.file_path, language=self.selected_language, task=self.selected_task,
|
self.transcription_options)
|
||||||
output_file_path=output_file, output_format=self.selected_output_format,
|
|
||||||
word_level_timings=self.enabled_word_level_timings,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.file_transcriber = WhisperFileTranscriber(
|
self.file_transcriber = WhisperFileTranscriber(
|
||||||
file_path=self.file_path, language=self.selected_language, task=self.selected_task,
|
self.transcription_options)
|
||||||
output_file_path=output_file, output_format=self.selected_output_format,
|
|
||||||
word_level_timings=self.enabled_word_level_timings,
|
|
||||||
temperature=self.temperature,
|
|
||||||
initial_prompt=self.initial_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_loader.moveToThread(self.transcriber_thread)
|
self.model_loader.moveToThread(self.transcriber_thread)
|
||||||
self.file_transcriber.moveToThread(self.transcriber_thread)
|
self.file_transcriber.moveToThread(self.transcriber_thread)
|
||||||
|
@ -515,6 +499,7 @@ class FileTranscriberWidget(QWidget):
|
||||||
self.model_loader.finished.connect(self.model_loader.deleteLater)
|
self.model_loader.finished.connect(self.model_loader.deleteLater)
|
||||||
|
|
||||||
# Run the file transcriber after the model loads
|
# Run the file transcriber after the model loads
|
||||||
|
self.model_loader.finished.connect(self.on_model_loaded)
|
||||||
self.model_loader.finished.connect(self.file_transcriber.run)
|
self.model_loader.finished.connect(self.file_transcriber.run)
|
||||||
|
|
||||||
self.file_transcriber.progress.connect(
|
self.file_transcriber.progress.connect(
|
||||||
|
@ -527,6 +512,9 @@ class FileTranscriberWidget(QWidget):
|
||||||
|
|
||||||
self.transcriber_thread.start()
|
self.transcriber_thread.start()
|
||||||
|
|
||||||
|
def on_model_loaded(self):
|
||||||
|
self.is_transcribing = True
|
||||||
|
|
||||||
def on_download_model_progress(self, progress: Tuple[int, int]):
|
def on_download_model_progress(self, progress: Tuple[int, int]):
|
||||||
(current_size, total_size) = progress
|
(current_size, total_size) = progress
|
||||||
|
|
||||||
|
@ -545,22 +533,26 @@ class FileTranscriberWidget(QWidget):
|
||||||
self.reset_transcriber_controls()
|
self.reset_transcriber_controls()
|
||||||
|
|
||||||
def on_transcriber_progress(self, progress: Tuple[int, int]):
|
def on_transcriber_progress(self, progress: Tuple[int, int]):
|
||||||
logging.debug('received progress = %s', progress)
|
|
||||||
(current_size, total_size) = progress
|
(current_size, total_size) = progress
|
||||||
|
|
||||||
# Create a dialog
|
if self.is_transcribing:
|
||||||
if self.transcriber_progress_dialog is None:
|
# Create a dialog
|
||||||
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
if self.transcriber_progress_dialog is None:
|
||||||
file_path=self.file_path, total_size=total_size, parent=self)
|
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
||||||
self.transcriber_progress_dialog.canceled.connect(
|
file_path=self.file_path, total_size=total_size, parent=self)
|
||||||
self.on_cancel_transcriber_progress_dialog)
|
self.transcriber_progress_dialog.canceled.connect(
|
||||||
|
self.on_cancel_transcriber_progress_dialog)
|
||||||
|
else:
|
||||||
|
# Update the progress of the dialog unless it has
|
||||||
|
# been canceled before this progress update arrived
|
||||||
|
self.transcriber_progress_dialog.update_progress(current_size)
|
||||||
|
|
||||||
# Update the progress of the dialog unless it has
|
@pyqtSlot(tuple)
|
||||||
# been canceled before this progress update arrived
|
def on_transcriber_complete(self, result: Tuple[int, List[Segment]]):
|
||||||
if self.transcriber_progress_dialog is not None:
|
exit_code, segments = result
|
||||||
self.transcriber_progress_dialog.update_progress(current_size)
|
|
||||||
|
self.is_transcribing = False
|
||||||
|
|
||||||
def on_transcriber_complete(self, exit_code: int):
|
|
||||||
if self.transcriber_progress_dialog is not None:
|
if self.transcriber_progress_dialog is not None:
|
||||||
self.transcriber_progress_dialog.reset()
|
self.transcriber_progress_dialog.reset()
|
||||||
if exit_code != 0:
|
if exit_code != 0:
|
||||||
|
@ -569,6 +561,10 @@ class FileTranscriberWidget(QWidget):
|
||||||
self.reset_transcriber_controls()
|
self.reset_transcriber_controls()
|
||||||
self.transcribed.emit()
|
self.transcribed.emit()
|
||||||
|
|
||||||
|
TranscriptionViewerWidget(
|
||||||
|
transcription_options=self.transcription_options,
|
||||||
|
segments=segments, parent=self, flags=Qt.WindowType.Window).show()
|
||||||
|
|
||||||
def on_cancel_transcriber_progress_dialog(self):
|
def on_cancel_transcriber_progress_dialog(self):
|
||||||
if self.file_transcriber is not None:
|
if self.file_transcriber is not None:
|
||||||
self.file_transcriber.stop()
|
self.file_transcriber.stop()
|
||||||
|
@ -590,6 +586,70 @@ class FileTranscriberWidget(QWidget):
|
||||||
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value
|
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionViewerWidget(QWidget):
|
||||||
|
segments: List[Segment]
|
||||||
|
transcription_options: FileTranscriptionOptions
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, transcription_options: FileTranscriptionOptions, segments: List[Segment],
|
||||||
|
parent: Optional['QWidget'] = None, flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(parent, flags)
|
||||||
|
self.segments = segments
|
||||||
|
self.transcription_options = transcription_options
|
||||||
|
|
||||||
|
self.setMinimumWidth(500)
|
||||||
|
self.setMinimumHeight(500)
|
||||||
|
|
||||||
|
self.setWindowTitle(
|
||||||
|
f'Transcription - {get_short_file_path(transcription_options.file_path)}')
|
||||||
|
|
||||||
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
|
text = segments_to_text(segments)
|
||||||
|
|
||||||
|
self.text_box = TextDisplayBox(self)
|
||||||
|
self.text_box.setPlainText(text)
|
||||||
|
|
||||||
|
layout.addWidget(self.text_box)
|
||||||
|
|
||||||
|
buttons_layout = QHBoxLayout()
|
||||||
|
buttons_layout.addStretch()
|
||||||
|
|
||||||
|
menu = QMenu()
|
||||||
|
actions = [QAction(text=output_format.value.upper(), parent=self)
|
||||||
|
for output_format in OutputFormat]
|
||||||
|
menu.addActions(actions)
|
||||||
|
|
||||||
|
menu.triggered.connect(self.on_menu_triggered)
|
||||||
|
|
||||||
|
export_button = QPushButton(self)
|
||||||
|
export_button.setText('Export')
|
||||||
|
export_button.setMenu(menu)
|
||||||
|
|
||||||
|
buttons_layout.addWidget(export_button)
|
||||||
|
layout.addLayout(buttons_layout)
|
||||||
|
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def on_menu_triggered(self, action: QAction):
|
||||||
|
output_format = OutputFormat[action.text()]
|
||||||
|
|
||||||
|
default_path = get_default_output_file_path(
|
||||||
|
task=self.transcription_options.task,
|
||||||
|
input_file_path=self.transcription_options.file_path,
|
||||||
|
output_format=output_format)
|
||||||
|
|
||||||
|
(output_file_path, _) = QFileDialog.getSaveFileName(
|
||||||
|
self, 'Save File', default_path, f'Text files (*.{output_format.value})')
|
||||||
|
|
||||||
|
if output_file_path == '':
|
||||||
|
return
|
||||||
|
|
||||||
|
write_output(path=output_file_path, segments=self.segments,
|
||||||
|
should_open=True, output_format=output_format)
|
||||||
|
|
||||||
|
|
||||||
class Settings(QSettings):
|
class Settings(QSettings):
|
||||||
_ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
|
_ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
|
||||||
|
|
||||||
|
@ -686,6 +746,7 @@ class RecordingTranscriberWidget(QWidget):
|
||||||
self.open_advanced_settings)
|
self.open_advanced_settings)
|
||||||
|
|
||||||
self.text_box = TextDisplayBox(self)
|
self.text_box = TextDisplayBox(self)
|
||||||
|
self.text_box.setPlaceholderText('Click Record to begin...')
|
||||||
|
|
||||||
widgets = [
|
widgets = [
|
||||||
((0, 5, FormLabel('Task:', self)), (5, 7, self.tasks_combo_box)),
|
((0, 5, FormLabel('Task:', self)), (5, 7, self.tasks_combo_box)),
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import ctypes
|
import ctypes
|
||||||
import datetime
|
import datetime
|
||||||
import enum
|
import enum
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import queue
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
@ -195,28 +197,33 @@ class OutputFormat(enum.Enum):
|
||||||
VTT = 'vtt'
|
VTT = 'vtt'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FileTranscriptionOptions:
|
||||||
|
file_path: str
|
||||||
|
language: Optional[str] = None
|
||||||
|
task: Task = Task.TRANSCRIBE
|
||||||
|
word_level_timings: bool = False
|
||||||
|
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
|
||||||
|
initial_prompt: str = ''
|
||||||
|
|
||||||
|
|
||||||
class WhisperCppFileTranscriber(QObject):
|
class WhisperCppFileTranscriber(QObject):
|
||||||
progress = pyqtSignal(tuple) # (current, total)
|
progress = pyqtSignal(tuple) # (current, total)
|
||||||
completed = pyqtSignal(int)
|
completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
|
||||||
error = pyqtSignal(str)
|
error = pyqtSignal(str)
|
||||||
duration_audio_ms = sys.maxsize # max int
|
duration_audio_ms = sys.maxsize # max int
|
||||||
segments: List[Segment]
|
segments: List[Segment]
|
||||||
running = False
|
running = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, language: Optional[str], task: Task, file_path: str,
|
self, transcription_options: FileTranscriptionOptions,
|
||||||
output_file_path: str, output_format: OutputFormat,
|
|
||||||
word_level_timings: bool, open_file_on_complete=True,
|
|
||||||
parent: Optional['QObject'] = None) -> None:
|
parent: Optional['QObject'] = None) -> None:
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
|
|
||||||
self.file_path = file_path
|
self.file_path = transcription_options.file_path
|
||||||
self.output_file_path = output_file_path
|
self.language = transcription_options.language
|
||||||
self.language = language
|
self.task = transcription_options.task
|
||||||
self.task = task
|
self.word_level_timings = transcription_options.word_level_timings
|
||||||
self.open_file_on_complete = open_file_on_complete
|
|
||||||
self.output_format = output_format
|
|
||||||
self.word_level_timings = word_level_timings
|
|
||||||
self.segments = []
|
self.segments = []
|
||||||
|
|
||||||
self.process = QProcess(self)
|
self.process = QProcess(self)
|
||||||
|
@ -228,8 +235,8 @@ class WhisperCppFileTranscriber(QObject):
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model_path = %s',
|
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s',
|
||||||
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path)
|
self.file_path, self.language, self.task, model_path)
|
||||||
|
|
||||||
wav_file = tempfile.mktemp()+'.wav'
|
wav_file = tempfile.mktemp()+'.wav'
|
||||||
(
|
(
|
||||||
|
@ -258,10 +265,8 @@ class WhisperCppFileTranscriber(QObject):
|
||||||
if status == QProcess.ExitStatus.NormalExit:
|
if status == QProcess.ExitStatus.NormalExit:
|
||||||
self.progress.emit(
|
self.progress.emit(
|
||||||
(self.duration_audio_ms, self.duration_audio_ms))
|
(self.duration_audio_ms, self.duration_audio_ms))
|
||||||
write_output(
|
|
||||||
self.output_file_path, self.segments, self.open_file_on_complete, self.output_format)
|
|
||||||
|
|
||||||
self.completed.emit(status == self.process.exitCode())
|
self.completed.emit((self.process.exitCode(), self.segments))
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
@ -282,7 +287,7 @@ class WhisperCppFileTranscriber(QObject):
|
||||||
segment = Segment(start, end, text.strip())
|
segment = Segment(start, end, text.strip())
|
||||||
self.segments.append(segment)
|
self.segments.append(segment)
|
||||||
self.progress.emit((end, self.duration_audio_ms))
|
self.progress.emit((end, self.duration_audio_ms))
|
||||||
except UnicodeDecodeError:
|
except (UnicodeDecodeError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def parse_timings(self, timings: str) -> Tuple[int, int]:
|
def parse_timings(self, timings: str) -> Tuple[int, int]:
|
||||||
|
@ -315,44 +320,39 @@ class WhisperFileTranscriber(QObject):
|
||||||
|
|
||||||
current_process: multiprocessing.Process
|
current_process: multiprocessing.Process
|
||||||
progress = pyqtSignal(tuple) # (current, total)
|
progress = pyqtSignal(tuple) # (current, total)
|
||||||
completed = pyqtSignal(int)
|
completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
|
||||||
error = pyqtSignal(str)
|
error = pyqtSignal(str)
|
||||||
running = False
|
running = False
|
||||||
read_line_thread: Optional[Thread] = None
|
read_line_thread: Optional[Thread] = None
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
language: Optional[str], task: Task, file_path: str,
|
self, transcription_options: FileTranscriptionOptions,
|
||||||
output_file_path: str, output_format: OutputFormat,
|
parent: Optional['QObject'] = None) -> None:
|
||||||
word_level_timings: bool, temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE,
|
|
||||||
initial_prompt: str = '', open_file_on_complete=True, parent: Optional['QObject'] = None) -> None:
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
|
|
||||||
self.file_path = file_path
|
self.file_path = transcription_options.file_path
|
||||||
self.output_file_path = output_file_path
|
self.language = transcription_options.language
|
||||||
self.language = language
|
self.task = transcription_options.task
|
||||||
self.task = task
|
self.word_level_timings = transcription_options.word_level_timings
|
||||||
self.open_file_on_complete = open_file_on_complete
|
self.temperature = transcription_options.temperature
|
||||||
self.output_format = output_format
|
self.initial_prompt = transcription_options.initial_prompt
|
||||||
self.word_level_timings = word_level_timings
|
self.segments = []
|
||||||
self.temperature = temperature
|
|
||||||
self.initial_prompt = initial_prompt
|
|
||||||
|
|
||||||
@pyqtSlot(str)
|
@pyqtSlot(str)
|
||||||
def run(self, model_path: str):
|
def run(self, model_path: str):
|
||||||
self.running = True
|
self.running = True
|
||||||
time_started = datetime.datetime.now()
|
time_started = datetime.datetime.now()
|
||||||
logging.debug(
|
logging.debug(
|
||||||
'Starting file transcription, file path = %s, language = %s, task = %s, output file path = %s, output format = %s, model path = %s, temperature = %s, initial prompt length = %s',
|
'Starting whisper file transcription, file path = %s, language = %s, task = %s, model path = %s, temperature = %s, initial prompt length = %s',
|
||||||
self.file_path, self.language, self.task, self.output_file_path, self.output_format, model_path, self.temperature, len(self.initial_prompt))
|
self.file_path, self.language, self.task, model_path, self.temperature, len(self.initial_prompt))
|
||||||
|
|
||||||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||||
|
|
||||||
self.current_process = multiprocessing.Process(
|
self.current_process = multiprocessing.Process(
|
||||||
target=transcribe_whisper,
|
target=transcribe_whisper,
|
||||||
args=(
|
args=(
|
||||||
send_pipe, model_path, self.file_path,
|
send_pipe, model_path, self.file_path,
|
||||||
self.language, self.task, self.output_file_path,
|
self.language, self.task, self.word_level_timings,
|
||||||
self.open_file_on_complete, self.output_format,
|
|
||||||
self.word_level_timings,
|
|
||||||
self.temperature, self.initial_prompt
|
self.temperature, self.initial_prompt
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -373,7 +373,9 @@ class WhisperFileTranscriber(QObject):
|
||||||
|
|
||||||
self.read_line_thread.join()
|
self.read_line_thread.join()
|
||||||
|
|
||||||
self.completed.emit(self.current_process.exitcode)
|
if self.current_process.exitcode != 0:
|
||||||
|
self.completed.emit((self.current_process.exitcode, []))
|
||||||
|
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
@ -381,6 +383,17 @@ class WhisperFileTranscriber(QObject):
|
||||||
self.current_process.terminate()
|
self.current_process.terminate()
|
||||||
|
|
||||||
def on_whisper_stdout(self, line: str):
|
def on_whisper_stdout(self, line: str):
|
||||||
|
if line.startswith('segments = '):
|
||||||
|
segments_dict = json.loads(line[11:])
|
||||||
|
segments = [Segment(
|
||||||
|
start=segment.get('start'),
|
||||||
|
end=segment.get('end'),
|
||||||
|
text=segment.get('text'),
|
||||||
|
) for segment in segments_dict]
|
||||||
|
self.current_process.join()
|
||||||
|
self.completed.emit((self.current_process.exitcode, segments))
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
progress = int(line.split('|')[0].strip().strip('%'))
|
progress = int(line.split('|')[0].strip().strip('%'))
|
||||||
self.progress.emit((progress, 100))
|
self.progress.emit((progress, 100))
|
||||||
|
@ -398,8 +411,7 @@ class WhisperFileTranscriber(QObject):
|
||||||
|
|
||||||
def transcribe_whisper(
|
def transcribe_whisper(
|
||||||
stderr_conn: Connection, model_path: str, file_path: str,
|
stderr_conn: Connection, model_path: str, file_path: str,
|
||||||
language: Optional[str], task: Task, output_file_path: str,
|
language: Optional[str], task: Task,
|
||||||
open_file_on_complete: bool, output_format: OutputFormat,
|
|
||||||
word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
|
word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
|
||||||
with pipe_stderr(stderr_conn):
|
with pipe_stderr(stderr_conn):
|
||||||
model = whisper.load_model(model_path)
|
model = whisper.load_model(model_path)
|
||||||
|
@ -418,15 +430,15 @@ def transcribe_whisper(
|
||||||
whisper_segments = stable_whisper.group_word_timestamps(
|
whisper_segments = stable_whisper.group_word_timestamps(
|
||||||
result) if word_level_timings else result.get('segments')
|
result) if word_level_timings else result.get('segments')
|
||||||
|
|
||||||
segments = map(
|
segments = [
|
||||||
lambda segment: Segment(
|
Segment(
|
||||||
start=segment.get('start')*1000, # s to ms
|
start=int(segment.get('start')*1000),
|
||||||
end=segment.get('end')*1000, # s to ms
|
end=int(segment.get('end')*1000),
|
||||||
text=segment.get('text')),
|
text=segment.get('text'),
|
||||||
whisper_segments)
|
) for segment in whisper_segments]
|
||||||
|
segments_json = json.dumps(
|
||||||
write_output(output_file_path, list(
|
segments, ensure_ascii=True, default=vars)
|
||||||
segments), open_file_on_complete, output_format)
|
sys.stderr.write(f'segments = {segments_json}\n')
|
||||||
|
|
||||||
|
|
||||||
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
|
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
|
||||||
|
@ -435,8 +447,11 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
|
||||||
|
|
||||||
with open(path, 'w', encoding='utf-8') as file:
|
with open(path, 'w', encoding='utf-8') as file:
|
||||||
if output_format == OutputFormat.TXT:
|
if output_format == OutputFormat.TXT:
|
||||||
for segment in segments:
|
for (i, segment) in enumerate(segments):
|
||||||
file.write(segment.text + ' ')
|
file.write(segment.text)
|
||||||
|
if i < len(segments)-1:
|
||||||
|
file.write(' ')
|
||||||
|
file.write('\n')
|
||||||
|
|
||||||
elif output_format == OutputFormat.VTT:
|
elif output_format == OutputFormat.VTT:
|
||||||
file.write('WEBVTT\n\n')
|
file.write('WEBVTT\n\n')
|
||||||
|
@ -460,6 +475,16 @@ def write_output(path: str, segments: List[Segment], should_open: bool, output_f
|
||||||
subprocess.call([opener, path])
|
subprocess.call([opener, path])
|
||||||
|
|
||||||
|
|
||||||
|
def segments_to_text(segments: List[Segment]) -> str:
|
||||||
|
result = ''
|
||||||
|
for (i, segment) in enumerate(segments):
|
||||||
|
result += f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n'
|
||||||
|
result += f'{segment.text}'
|
||||||
|
if i < len(segments)-1:
|
||||||
|
result += '\n\n'
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def to_timestamp(ms: float) -> str:
|
def to_timestamp(ms: float) -> str:
|
||||||
hr = int(ms / (1000*60*60))
|
hr = int(ms / (1000*60*60))
|
||||||
ms = ms - hr * (1000*60*60)
|
ms = ms - hr * (1000*60*60)
|
||||||
|
|
|
@ -1,21 +1,24 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from typing import Any, Callable
|
||||||
|
import platform
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sounddevice
|
import sounddevice
|
||||||
from PyQt6.QtCore import Qt, QCoreApplication
|
from PyQt6.QtCore import Qt, QCoreApplication
|
||||||
from PyQt6.QtGui import (QValidator)
|
from PyQt6.QtGui import (QValidator)
|
||||||
|
from PyQt6.QtWidgets import (QPushButton)
|
||||||
from pytestqt.qtbot import QtBot
|
from pytestqt.qtbot import QtBot
|
||||||
|
|
||||||
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
|
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
|
||||||
AudioDevicesComboBox, DownloadModelProgressDialog,
|
AudioDevicesComboBox, DownloadModelProgressDialog,
|
||||||
FileTranscriberWidget, LanguagesComboBox, MainWindow,
|
FileTranscriberWidget, LanguagesComboBox, MainWindow,
|
||||||
OutputFormatsComboBox, Quality, QualityComboBox,
|
OutputFormatsComboBox, Quality, QualityComboBox,
|
||||||
Settings, TemperatureValidator,
|
Settings, TemperatureValidator, TextDisplayBox,
|
||||||
TranscriberProgressDialog)
|
TranscriberProgressDialog, TranscriptionViewerWidget)
|
||||||
from buzz.transcriber import OutputFormat
|
from buzz.transcriber import FileTranscriptionOptions, OutputFormat, Segment
|
||||||
|
|
||||||
|
|
||||||
class TestApplication:
|
class TestApplication:
|
||||||
|
@ -142,25 +145,29 @@ class TestTranscriberProgressDialog:
|
||||||
|
|
||||||
|
|
||||||
class TestDownloadModelProgressDialog:
|
class TestDownloadModelProgressDialog:
|
||||||
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
def test_should_show_dialog(self, qtbot: QtBot):
|
||||||
|
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||||
|
qtbot.add_widget(dialog)
|
||||||
|
assert dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
|
||||||
|
|
||||||
def test_should_show_dialog(self):
|
def test_should_update_label_on_progress(self, qtbot: QtBot):
|
||||||
assert self.dialog.labelText() == 'Downloading resources (0%, unknown time remaining)'
|
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||||
|
qtbot.add_widget(dialog)
|
||||||
|
dialog.setValue(0)
|
||||||
|
|
||||||
def test_should_update_label_on_progress(self):
|
dialog.setValue(12345)
|
||||||
self.dialog.setValue(0)
|
assert dialog.labelText().startswith(
|
||||||
|
|
||||||
self.dialog.setValue(12345)
|
|
||||||
assert self.dialog.labelText().startswith(
|
|
||||||
'Downloading resources (1.00%')
|
'Downloading resources (1.00%')
|
||||||
|
|
||||||
self.dialog.setValue(123456)
|
dialog.setValue(123456)
|
||||||
assert self.dialog.labelText().startswith(
|
assert dialog.labelText().startswith(
|
||||||
'Downloading resources (10.00%')
|
'Downloading resources (10.00%')
|
||||||
|
|
||||||
# Other windows should not be processing while models are being downloaded
|
# Other windows should not be processing while models are being downloaded
|
||||||
def test_should_be_an_application_modal(self):
|
def test_should_be_an_application_modal(self, qtbot: QtBot):
|
||||||
assert self.dialog.windowModality() == Qt.WindowModality.ApplicationModal
|
dialog = DownloadModelProgressDialog(total_size=1234567, parent=None)
|
||||||
|
qtbot.add_widget(dialog)
|
||||||
|
assert dialog.windowModality() == Qt.WindowModality.ApplicationModal
|
||||||
|
|
||||||
|
|
||||||
class TestFormatsComboBox:
|
class TestFormatsComboBox:
|
||||||
|
@ -177,21 +184,32 @@ class TestMainWindow:
|
||||||
assert main_window is not None
|
assert main_window is not None
|
||||||
|
|
||||||
|
|
||||||
|
def wait_until(callback: Callable[[], Any], timeout=0):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
QCoreApplication.processEvents()
|
||||||
|
callback()
|
||||||
|
return
|
||||||
|
except AssertionError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestFileTranscriberWidget:
|
class TestFileTranscriberWidget:
|
||||||
def test_should_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path):
|
@pytest.mark.skipif(condition=platform.system() == 'Windows', reason='Waiting for signal crashes process on Windows')
|
||||||
|
def test_should_transcribe(self, qtbot: QtBot):
|
||||||
widget = FileTranscriberWidget(
|
widget = FileTranscriberWidget(
|
||||||
file_path='testdata/whisper-french.mp3', parent=None)
|
file_path='testdata/whisper-french.mp3', parent=None)
|
||||||
qtbot.addWidget(widget)
|
qtbot.addWidget(widget)
|
||||||
|
|
||||||
output_file_path = tmp_path / 'whisper.txt'
|
# Waiting for a "transcribed" signal seems to work more consistently
|
||||||
|
# than checking for the opening of a TranscriptionViewerWidget.
|
||||||
|
# See also: https://github.com/pytest-dev/pytest-qt/issues/313
|
||||||
|
with qtbot.wait_signal(widget.transcribed, timeout=30*1000):
|
||||||
|
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
|
||||||
|
|
||||||
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock,
|
transcription_viewer = widget.findChild(TranscriptionViewerWidget)
|
||||||
qtbot.wait_signal(widget.transcribed, timeout=10*1000)):
|
assert isinstance(transcription_viewer, TranscriptionViewerWidget)
|
||||||
save_file_name_mock.return_value = (output_file_path, '')
|
assert len(transcription_viewer.segments) > 0
|
||||||
widget.run_button.click()
|
|
||||||
|
|
||||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
|
||||||
assert 'Bienvenue dans Passe' in output_file.read()
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="transcription_started callback sometimes not getting called until all progress events are emitted")
|
@pytest.mark.skip(reason="transcription_started callback sometimes not getting called until all progress events are emitted")
|
||||||
def test_should_transcribe_and_stop(self, qtbot: QtBot, tmp_path: pathlib.Path):
|
def test_should_transcribe_and_stop(self, qtbot: QtBot, tmp_path: pathlib.Path):
|
||||||
|
@ -202,13 +220,12 @@ class TestFileTranscriberWidget:
|
||||||
output_file_path = tmp_path / 'whisper.txt'
|
output_file_path = tmp_path / 'whisper.txt'
|
||||||
|
|
||||||
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock):
|
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock):
|
||||||
save_file_name_mock.return_value = (output_file_path, '')
|
save_file_name_mock.return_value = (str(output_file_path), '')
|
||||||
widget.run_button.click()
|
widget.run_button.click()
|
||||||
|
|
||||||
def transcription_started():
|
def transcription_started():
|
||||||
QCoreApplication.processEvents()
|
QCoreApplication.processEvents()
|
||||||
assert widget.transcriber_progress_dialog is not None
|
assert widget.transcriber_progress_dialog is not None
|
||||||
logging.debug('asserted value = %s', widget.transcriber_progress_dialog.value())
|
|
||||||
assert widget.transcriber_progress_dialog.value() > 0
|
assert widget.transcriber_progress_dialog.value() > 0
|
||||||
qtbot.wait_until(transcription_started, timeout=30*1000)
|
qtbot.wait_until(transcription_started, timeout=30*1000)
|
||||||
|
|
||||||
|
@ -270,3 +287,30 @@ class TestTemperatureValidator:
|
||||||
])
|
])
|
||||||
def test_should_validate_temperature(self, text: str, state: QValidator.State):
|
def test_should_validate_temperature(self, text: str, state: QValidator.State):
|
||||||
assert self.validator.validate(text, 0)[0] == state
|
assert self.validator.validate(text, 0)[0] == state
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptionViewerWidget:
|
||||||
|
widget = TranscriptionViewerWidget(
|
||||||
|
transcription_options=FileTranscriptionOptions(
|
||||||
|
file_path='testdata/whisper-french.mp3'),
|
||||||
|
segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')])
|
||||||
|
|
||||||
|
def test_should_display_segments(self):
|
||||||
|
assert self.widget.windowTitle() == 'Transcription - whisper-french.mp3'
|
||||||
|
|
||||||
|
text_display_box = self.widget.findChild(TextDisplayBox)
|
||||||
|
assert isinstance(text_display_box, TextDisplayBox)
|
||||||
|
assert text_display_box.toPlainText(
|
||||||
|
) == '00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans'
|
||||||
|
|
||||||
|
def test_should_export_segments(self, tmp_path: pathlib.Path):
|
||||||
|
export_button = self.widget.findChild(QPushButton)
|
||||||
|
assert isinstance(export_button, QPushButton)
|
||||||
|
|
||||||
|
output_file_path = tmp_path / 'whisper.txt'
|
||||||
|
with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock:
|
||||||
|
save_file_name_mock.return_value = (str(output_file_path), '')
|
||||||
|
export_button.menu().actions()[0].trigger()
|
||||||
|
|
||||||
|
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||||
|
assert 'Bien venue dans' in output_file.read()
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
from typing import List
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from PyQt6.QtCore import QCoreApplication
|
from PyQt6.QtCore import QCoreApplication
|
||||||
|
@ -9,11 +11,11 @@ import pytest
|
||||||
from pytestqt.qtbot import QtBot
|
from pytestqt.qtbot import QtBot
|
||||||
|
|
||||||
from buzz.model_loader import ModelLoader
|
from buzz.model_loader import ModelLoader
|
||||||
from buzz.transcriber import (OutputFormat, RecordingTranscriber, Task,
|
from buzz.transcriber import (FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Segment, Task,
|
||||||
WhisperCpp, WhisperCppFileTranscriber,
|
WhisperCpp, WhisperCppFileTranscriber,
|
||||||
WhisperFileTranscriber,
|
WhisperFileTranscriber,
|
||||||
get_default_output_file_path, to_timestamp,
|
get_default_output_file_path, to_timestamp,
|
||||||
whisper_cpp_params)
|
whisper_cpp_params, write_output)
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_name: str, use_whisper_cpp: bool) -> str:
|
def get_model_path(model_name: str, use_whisper_cpp: bool) -> str:
|
||||||
|
@ -40,33 +42,31 @@ class TestRecordingTranscriber:
|
||||||
|
|
||||||
class TestWhisperCppFileTranscriber:
|
class TestWhisperCppFileTranscriber:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'task,output_text',
|
'word_level_timings,expected_segments',
|
||||||
[
|
[
|
||||||
(Task.TRANSCRIBE, 'Bienvenue dans Passe'),
|
(False, [Segment(0, 1840, 'Bienvenue dans Passe Relle.')]),
|
||||||
(Task.TRANSLATE, 'Welcome to Passe-Relle'),
|
(True, [Segment(30, 280, 'Bien'), Segment(280, 630, 'venue')])
|
||||||
])
|
])
|
||||||
def test_transcribe(self, qtbot, tmp_path: pathlib.Path, task: Task, output_text: str):
|
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||||
output_file_path = tmp_path / 'whisper_cpp.txt'
|
transcription_options = FileTranscriptionOptions(
|
||||||
if os.path.exists(output_file_path):
|
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||||
os.remove(output_file_path)
|
word_level_timings=word_level_timings)
|
||||||
|
|
||||||
model_path = get_model_path('tiny', True)
|
model_path = get_model_path('tiny', True)
|
||||||
transcriber = WhisperCppFileTranscriber(
|
transcriber = WhisperCppFileTranscriber(
|
||||||
language='fr', task=task, file_path='testdata/whisper-french.mp3',
|
transcription_options=transcription_options)
|
||||||
output_file_path=output_file_path.as_posix(), output_format=OutputFormat.TXT,
|
|
||||||
open_file_on_complete=False,
|
|
||||||
word_level_timings=False)
|
|
||||||
mock_progress = Mock()
|
mock_progress = Mock()
|
||||||
with qtbot.waitSignal(transcriber.completed, timeout=10*60*1000):
|
mock_completed = Mock()
|
||||||
transcriber.progress.connect(mock_progress)
|
transcriber.progress.connect(mock_progress)
|
||||||
|
transcriber.completed.connect(mock_completed)
|
||||||
|
with qtbot.waitSignal(transcriber.completed, timeout=10 * 60 * 1000):
|
||||||
transcriber.run(model_path)
|
transcriber.run(model_path)
|
||||||
|
|
||||||
assert os.path.isfile(output_file_path)
|
|
||||||
|
|
||||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
|
||||||
assert output_text in output_file.read()
|
|
||||||
|
|
||||||
mock_progress.assert_called()
|
mock_progress.assert_called()
|
||||||
|
exit_code, segments = mock_completed.call_args[0][0]
|
||||||
|
assert exit_code is 0
|
||||||
|
for expected_segment in expected_segments:
|
||||||
|
assert expected_segment in segments
|
||||||
|
|
||||||
|
|
||||||
class TestWhisperFileTranscriber:
|
class TestWhisperFileTranscriber:
|
||||||
|
@ -82,36 +82,31 @@ class TestWhisperFileTranscriber:
|
||||||
assert srt.endswith('.srt')
|
assert srt.endswith('.srt')
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'word_level_timings,output_format,output_text',
|
'word_level_timings,expected_segments',
|
||||||
[
|
[
|
||||||
(False, OutputFormat.TXT, 'Bienvenue dans Passe-Relle'),
|
(False, [
|
||||||
(False, OutputFormat.SRT,
|
Segment(
|
||||||
'1\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe-Relle'),
|
0, 6560,
|
||||||
(False, OutputFormat.VTT,
|
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'),
|
||||||
'WEBVTT\n\n00:00:00.000 --> 00:00:06.560\n Bienvenue dans Passe'),
|
]),
|
||||||
(True, OutputFormat.SRT,
|
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')])
|
||||||
'1\n00:00:00.040 --> 00:00:00.299\n Bien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n3\n00:00:00.329 --> 00:00:00.429\n P\n\n4\n00:00:00.429 --> 00:00:00.589\nasse-'),
|
|
||||||
])
|
])
|
||||||
def test_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path, word_level_timings: bool, output_format: OutputFormat, output_text: str):
|
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||||
output_file_path = tmp_path / f'whisper.{output_format.value.lower()}'
|
|
||||||
|
|
||||||
model_path = get_model_path('tiny', False)
|
model_path = get_model_path('tiny', False)
|
||||||
|
|
||||||
mock_progress = Mock()
|
mock_progress = Mock()
|
||||||
transcriber = WhisperFileTranscriber(
|
mock_completed = Mock()
|
||||||
|
transcription_options = FileTranscriptionOptions(
|
||||||
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||||
output_file_path=output_file_path.as_posix(), output_format=output_format,
|
|
||||||
open_file_on_complete=False,
|
|
||||||
word_level_timings=word_level_timings)
|
word_level_timings=word_level_timings)
|
||||||
|
|
||||||
|
transcriber = WhisperFileTranscriber(
|
||||||
|
transcription_options=transcription_options)
|
||||||
transcriber.progress.connect(mock_progress)
|
transcriber.progress.connect(mock_progress)
|
||||||
with qtbot.wait_signal(transcriber.completed, timeout=10*6000):
|
transcriber.completed.connect(mock_completed)
|
||||||
|
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
|
||||||
transcriber.run(model_path)
|
transcriber.run(model_path)
|
||||||
|
|
||||||
assert os.path.isfile(output_file_path)
|
|
||||||
|
|
||||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
|
||||||
assert output_text in output_file.read()
|
|
||||||
|
|
||||||
QCoreApplication.processEvents()
|
QCoreApplication.processEvents()
|
||||||
|
|
||||||
# Reports progress at 0, 0<progress<100, and 100
|
# Reports progress at 0, 0<progress<100, and 100
|
||||||
|
@ -120,7 +115,14 @@ class TestWhisperFileTranscriber:
|
||||||
assert any(
|
assert any(
|
||||||
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
|
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
|
||||||
assert any(
|
assert any(
|
||||||
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in mock_progress.call_args_list])
|
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
|
||||||
|
mock_progress.call_args_list])
|
||||||
|
|
||||||
|
mock_completed.assert_called()
|
||||||
|
exit_code, segments = mock_completed.call_args[0][0]
|
||||||
|
assert exit_code is 0
|
||||||
|
for (i, expected_segment) in enumerate(expected_segments):
|
||||||
|
assert segments[i] == expected_segment
|
||||||
|
|
||||||
@pytest.mark.skip()
|
@pytest.mark.skip()
|
||||||
def test_transcribe_stop(self):
|
def test_transcribe_stop(self):
|
||||||
|
@ -129,11 +131,12 @@ class TestWhisperFileTranscriber:
|
||||||
os.remove(output_file_path)
|
os.remove(output_file_path)
|
||||||
|
|
||||||
model_path = get_model_path('tiny', False)
|
model_path = get_model_path('tiny', False)
|
||||||
transcriber = WhisperFileTranscriber(
|
transcription_options = FileTranscriptionOptions(
|
||||||
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
language='fr', task=Task.TRANSCRIBE, file_path='testdata/whisper-french.mp3',
|
||||||
output_file_path=output_file_path, output_format=OutputFormat.TXT,
|
|
||||||
open_file_on_complete=False,
|
|
||||||
word_level_timings=False)
|
word_level_timings=False)
|
||||||
|
|
||||||
|
transcriber = WhisperFileTranscriber(
|
||||||
|
transcription_options=transcription_options)
|
||||||
transcriber.run(model_path)
|
transcriber.run(model_path)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
transcriber.stop()
|
transcriber.stop()
|
||||||
|
@ -159,3 +162,24 @@ class TestWhisperCpp:
|
||||||
audio='testdata/whisper-french.mp3', params=params)
|
audio='testdata/whisper-french.mp3', params=params)
|
||||||
|
|
||||||
assert 'Bienvenue dans Passe' in result['text']
|
assert 'Bienvenue dans Passe' in result['text']
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'output_format,output_text',
|
||||||
|
[
|
||||||
|
(OutputFormat.TXT, 'Bien venue dans\n'),
|
||||||
|
(
|
||||||
|
OutputFormat.SRT,
|
||||||
|
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||||
|
(OutputFormat.VTT,
|
||||||
|
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||||
|
])
|
||||||
|
def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str):
|
||||||
|
output_file_path = tmp_path / 'whisper.txt'
|
||||||
|
segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]
|
||||||
|
|
||||||
|
write_output(path=str(output_file_path), segments=segments,
|
||||||
|
should_open=False, output_format=output_format)
|
||||||
|
|
||||||
|
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||||
|
assert output_text == output_file.read()
|
||||||
|
|
Loading…
Reference in a new issue