mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-29 13:10:26 +02:00
Add task queue (#253)
This commit is contained in:
parent
acd2d93e69
commit
0e086bd593
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -11,3 +11,4 @@ whisper_cpp
|
|||
whisper.dll
|
||||
whisper_cpp.py
|
||||
coverage.xml
|
||||
.idea/
|
||||
|
|
2
Makefile
2
Makefile
|
@ -43,7 +43,7 @@ clean:
|
|||
rm -rf dist/* || true
|
||||
|
||||
test: buzz/whisper_cpp.py
|
||||
pytest --cov --cov-report=xml
|
||||
pytest --cov=buzz --cov-report=xml --cov-report=html
|
||||
|
||||
dist/Buzz dist/Buzz.app: buzz/whisper_cpp.py
|
||||
pyinstaller --noconfirm Buzz.spec
|
||||
|
|
5
assets/circle-plus-icon.svg
Normal file
5
assets/circle-plus-icon.svg
Normal file
|
@ -0,0 +1,5 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 512 512"><!--! Font Awesome Pro 6.2.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2022 Fonticons, Inc. -->
|
||||
<path fill="#888"
|
||||
d="M256 512c141.4 0 256-114.6 256-256S397.4 0 256 0S0 114.6 0 256S114.6 512 256 512zM232 344V280H168c-13.3 0-24-10.7-24-24s10.7-24 24-24h64V168c0-13.3 10.7-24 24-24s24 10.7 24 24v64h64c13.3 0 24 10.7 24 24s-10.7 24-24 24H280v64c0 13.3-10.7 24-24 24s-24-10.7-24-24z"/>
|
||||
</svg>
|
After Width: | Height: | Size: 542 B |
5
assets/record-icon.svg
Normal file
5
assets/record-icon.svg
Normal file
|
@ -0,0 +1,5 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 384 512"><!--! Font Awesome Pro 6.2.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2022 Fonticons, Inc. -->
|
||||
<path fill="#888"
|
||||
d="M192 0C139 0 96 43 96 96V256c0 53 43 96 96 96s96-43 96-96V96c0-53-43-96-96-96zM64 216c0-13.3-10.7-24-24-24s-24 10.7-24 24v40c0 89.1 66.2 162.7 152 174.4V464H120c-13.3 0-24 10.7-24 24s10.7 24 24 24h72 72c13.3 0 24-10.7 24-24s-10.7-24-24-24H216V430.4c85.8-11.7 152-85.3 152-174.4V216c0-13.3-10.7-24-24-24s-24 10.7-24 24v40c0 70.7-57.3 128-128 128s-128-57.3-128-128V216z"/>
|
||||
</svg>
|
After Width: | Height: | Size: 648 B |
5
assets/up-down-and-down-left-from-center-icon.svg
Normal file
5
assets/up-down-and-down-left-from-center-icon.svg
Normal file
|
@ -0,0 +1,5 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 512 512"><!--! Font Awesome Pro 6.2.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2022 Fonticons, Inc. -->
|
||||
<path fill="#888"
|
||||
d="M344 0H488c13.3 0 24 10.7 24 24V168c0 9.7-5.8 18.5-14.8 22.2s-19.3 1.7-26.2-5.2l-39-39-87 87c-9.4 9.4-24.6 9.4-33.9 0l-32-32c-9.4-9.4-9.4-24.6 0-33.9l87-87L327 41c-6.9-6.9-8.9-17.2-5.2-26.2S334.3 0 344 0zM184 496H40c-13.3 0-24-10.7-24-24V328c0-9.7 5.8-18.5 14.8-22.2s19.3-1.7 26.2 5.2l39 39 87-87c9.4-9.4 24.6-9.4 33.9 0l32 32c9.4 9.4 9.4 24.6 0 33.9l-87 87 39 39c6.9 6.9 8.9 17.2 5.2 26.2s-12.5 14.8-22.2 14.8z"/>
|
||||
</svg>
|
After Width: | Height: | Size: 692 B |
10
buzz/conn.py
10
buzz/conn.py
|
@ -19,13 +19,3 @@ def pipe_stderr(conn: Connection):
|
|||
yield
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pipe_stdout(conn: Connection):
|
||||
sys.stdout = ConnWriter(conn)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.stdout = sys.__stdout__
|
||||
|
|
476
buzz/gui.py
476
buzz/gui.py
|
@ -9,15 +9,16 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
import humanize
|
||||
import sounddevice
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QThread, pyqtSlot,
|
||||
QThreadPool, QTimer, QUrl, pyqtSignal)
|
||||
from PyQt6.QtCore import (QDateTime, QObject, QSettings, Qt, QThread, pyqtSlot,
|
||||
QTimer, QUrl, pyqtSignal, QModelIndex, QSize)
|
||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||
QKeySequence, QPixmap, QTextCursor, QValidator)
|
||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||
QDialogButtonBox, QFileDialog, QLabel, QLineEdit,
|
||||
QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
|
||||
QWidget, QGroupBox, QFormLayout)
|
||||
QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem,
|
||||
QHeaderView, QAbstractItemView)
|
||||
from requests import get
|
||||
from whisper import tokenizer
|
||||
|
||||
|
@ -27,7 +28,7 @@ from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Ou
|
|||
RecordingTranscriber, Segment, Task,
|
||||
WhisperCppFileTranscriber, WhisperFileTranscriber,
|
||||
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
|
||||
Model)
|
||||
Model, FileTranscriberQueueWorker, FileTranscriptionTask)
|
||||
|
||||
APP_NAME = 'Buzz'
|
||||
|
||||
|
@ -308,31 +309,42 @@ class TimerLabel(QLabel):
|
|||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more information.'
|
||||
message = f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more ' \
|
||||
f'information. '
|
||||
QMessageBox.critical(parent, '', message)
|
||||
|
||||
|
||||
class FileTranscriberWidget(QWidget):
|
||||
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
|
||||
file_transcriber: Optional[Union[WhisperFileTranscriber, WhisperCppFileTranscriber]] = None
|
||||
file_transcriber: Optional[Union[WhisperFileTranscriber,
|
||||
WhisperCppFileTranscriber]] = None
|
||||
model_loader: Optional[ModelLoader] = None
|
||||
transcriber_thread: Optional[QThread] = None
|
||||
file_transcription_options: FileTranscriptionOptions
|
||||
transcription_options: TranscriptionOptions
|
||||
is_transcribing = False
|
||||
# (TranscriptionOptions, FileTranscriptionOptions, str)
|
||||
triggered = pyqtSignal(tuple)
|
||||
|
||||
def __init__(self, file_path: str, parent: Optional[QWidget] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
|
||||
super().__init__(parent, flags)
|
||||
|
||||
self.setWindowTitle(get_short_file_path(file_path))
|
||||
self.setFixedSize(420, 270)
|
||||
|
||||
def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
|
||||
super().__init__(parent)
|
||||
self.file_path = file_path
|
||||
self.transcription_options = TranscriptionOptions()
|
||||
self.file_transcription_options = FileTranscriptionOptions(file_path=self.file_path)
|
||||
self.file_transcription_options = FileTranscriptionOptions(
|
||||
file_path=self.file_path)
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
transcription_options_group_box = TranscriptionOptionsGroupBox(
|
||||
default_transcription_options=self.transcription_options, parent=self)
|
||||
transcription_options_group_box.transcription_options_changed.connect(self.on_transcription_options_changed)
|
||||
transcription_options_group_box.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
|
||||
self.word_level_timings_checkbox = QCheckBox('Word-level timings')
|
||||
self.word_level_timings_checkbox.stateChanged.connect(
|
||||
|
@ -350,7 +362,6 @@ class FileTranscriberWidget(QWidget):
|
|||
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.pool = QThreadPool()
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
|
@ -359,50 +370,31 @@ class FileTranscriberWidget(QWidget):
|
|||
self.run_button.setDisabled(True)
|
||||
|
||||
self.transcriber_thread = QThread()
|
||||
|
||||
self.model_loader = ModelLoader(model=self.transcription_options.model)
|
||||
|
||||
if self.transcription_options.model.is_whisper_cpp():
|
||||
self.file_transcriber = WhisperCppFileTranscriber(
|
||||
file_transcription_options=self.file_transcription_options,
|
||||
transcription_options=self.transcription_options)
|
||||
else:
|
||||
self.file_transcriber = WhisperFileTranscriber(file_transcription_options=self.file_transcription_options,
|
||||
transcription_options=self.transcription_options)
|
||||
|
||||
self.model_loader.moveToThread(self.transcriber_thread)
|
||||
self.file_transcriber.moveToThread(self.transcriber_thread)
|
||||
|
||||
self.transcriber_thread.started.connect(self.model_loader.run)
|
||||
self.model_loader.finished.connect(
|
||||
self.transcriber_thread.quit)
|
||||
|
||||
self.model_loader.progress.connect(
|
||||
self.on_download_model_progress)
|
||||
self.model_loader.progress.connect(self.on_download_model_progress)
|
||||
|
||||
self.model_loader.error.connect(self.on_download_model_error)
|
||||
self.model_loader.error.connect(
|
||||
self.model_loader.deleteLater)
|
||||
self.model_loader.error.connect(
|
||||
self.file_transcriber.deleteLater)
|
||||
self.model_loader.error.connect(self.transcriber_thread.quit)
|
||||
self.model_loader.error.connect(self.model_loader.deleteLater)
|
||||
|
||||
self.model_loader.finished.connect(self.on_model_loaded)
|
||||
self.model_loader.finished.connect(self.model_loader.deleteLater)
|
||||
|
||||
# Run the file transcriber after the model loads
|
||||
self.model_loader.finished.connect(self.on_model_loaded)
|
||||
self.model_loader.finished.connect(self.file_transcriber.run)
|
||||
|
||||
self.file_transcriber.progress.connect(
|
||||
self.on_transcriber_progress)
|
||||
|
||||
self.file_transcriber.completed.connect(self.on_transcriber_complete)
|
||||
self.file_transcriber.completed.connect(self.transcriber_thread.quit)
|
||||
self.transcriber_thread.finished.connect(
|
||||
self.transcriber_thread.deleteLater)
|
||||
|
||||
self.transcriber_thread.start()
|
||||
|
||||
def on_model_loaded(self):
|
||||
self.is_transcribing = True
|
||||
def on_model_loaded(self, model_path: str):
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
self.triggered.emit((self.transcription_options,
|
||||
self.file_transcription_options, model_path))
|
||||
self.close()
|
||||
|
||||
def on_download_model_progress(self, progress: Tuple[int, int]):
|
||||
(current_size, total_size) = progress
|
||||
|
@ -421,44 +413,6 @@ class FileTranscriberWidget(QWidget):
|
|||
show_model_download_error_dialog(self, error)
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
def on_transcriber_progress(self, progress: Tuple[int, int]):
|
||||
(current_size, total_size) = progress
|
||||
|
||||
if self.is_transcribing:
|
||||
# Create a dialog
|
||||
if self.transcriber_progress_dialog is None:
|
||||
self.transcriber_progress_dialog = TranscriberProgressDialog(
|
||||
file_path=self.file_path, total_size=total_size, parent=self)
|
||||
self.transcriber_progress_dialog.canceled.connect(
|
||||
self.on_cancel_transcriber_progress_dialog)
|
||||
else:
|
||||
# Update the progress of the dialog unless it has
|
||||
# been canceled before this progress update arrived
|
||||
self.transcriber_progress_dialog.update_progress(current_size)
|
||||
|
||||
@pyqtSlot(tuple)
|
||||
def on_transcriber_complete(self, result: Tuple[int, List[Segment]]):
|
||||
exit_code, segments = result
|
||||
|
||||
self.is_transcribing = False
|
||||
|
||||
if self.transcriber_progress_dialog is not None:
|
||||
self.transcriber_progress_dialog.reset()
|
||||
if exit_code != 0:
|
||||
self.transcriber_progress_dialog.close()
|
||||
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
TranscriptionViewerWidget(
|
||||
transcription_options=self.transcription_options,
|
||||
file_transcription_options=self.file_transcription_options,
|
||||
segments=segments, parent=self, flags=Qt.WindowType.Window).show()
|
||||
|
||||
def on_cancel_transcriber_progress_dialog(self):
|
||||
if self.file_transcriber is not None:
|
||||
self.file_transcriber.stop()
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
def reset_transcriber_controls(self):
|
||||
self.run_button.setDisabled(False)
|
||||
|
||||
|
@ -541,23 +495,6 @@ class TranscriptionViewerWidget(QWidget):
|
|||
should_open=True, output_format=output_format)
|
||||
|
||||
|
||||
class Settings(QSettings):
|
||||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__('Buzz', 'Buzz', parent)
|
||||
logging.debug('Loaded settings from path = %s', self.fileName())
|
||||
|
||||
# Convert QSettings value to boolean: https://forum.qt.io/topic/108622/how-to-get-a-boolean-value-from-qsettings-correctly
|
||||
@staticmethod
|
||||
def _value_to_bool(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return value.lower() == 'true'
|
||||
|
||||
return bool(value)
|
||||
|
||||
|
||||
class AdvancedSettingsButton(QPushButton):
|
||||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__('Advanced...', parent)
|
||||
|
@ -572,11 +509,14 @@ class RecordingTranscriberWidget(QWidget):
|
|||
model_loader: Optional[ModelLoader] = None
|
||||
model_loader_thread: Optional[QThread] = None
|
||||
|
||||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__(parent)
|
||||
def __init__(self, parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
|
||||
super().__init__(parent, flags)
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
self.setWindowTitle('Live Recording')
|
||||
self.setFixedSize(400, 520)
|
||||
|
||||
self.transcription_options = TranscriptionOptions()
|
||||
|
||||
self.audio_devices_combo_box = AudioDevicesComboBox(self)
|
||||
|
@ -594,10 +534,12 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
transcription_options_group_box = TranscriptionOptionsGroupBox(
|
||||
default_transcription_options=self.transcription_options, parent=self)
|
||||
transcription_options_group_box.transcription_options_changed.connect(self.on_transcription_options_changed)
|
||||
transcription_options_group_box.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
|
||||
recording_options_layout = QFormLayout()
|
||||
recording_options_layout.addRow('Microphone:', self.audio_devices_combo_box)
|
||||
recording_options_layout.addRow(
|
||||
'Microphone:', self.audio_devices_combo_box)
|
||||
|
||||
record_button_layout = QHBoxLayout()
|
||||
record_button_layout.addStretch()
|
||||
|
@ -611,6 +553,10 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
self.setLayout(layout)
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
self.stop_recording()
|
||||
return super().closeEvent(event)
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
|
@ -626,7 +572,8 @@ class RecordingTranscriberWidget(QWidget):
|
|||
def start_recording(self):
|
||||
self.record_button.setDisabled(True)
|
||||
|
||||
use_whisper_cpp = self.transcription_options.model.is_whisper_cpp() and self.transcription_options.language is not None
|
||||
use_whisper_cpp = self.transcription_options.model.is_whisper_cpp(
|
||||
) and self.transcription_options.language is not None
|
||||
|
||||
def start_recording_transcription(model_path: str):
|
||||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
|
@ -716,34 +663,33 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.model_download_progress_dialog = None
|
||||
|
||||
|
||||
ICON_PATH = '../assets/buzz.ico'
|
||||
ICON_LARGE_PATH = '../assets/buzz-icon-1024.png'
|
||||
|
||||
|
||||
def get_asset_path(path: str):
|
||||
base_dir = os.path.dirname(sys.executable if getattr(
|
||||
sys, 'frozen', False) else __file__)
|
||||
return os.path.join(base_dir, path)
|
||||
|
||||
|
||||
class AppIcon(QIcon):
|
||||
def __init__(self):
|
||||
super().__init__(get_asset_path(ICON_PATH))
|
||||
BUZZ_ICON_PATH = get_asset_path('../assets/buzz.ico')
|
||||
BUZZ_LARGE_ICON_PATH = get_asset_path('../assets/buzz-icon-1024.png')
|
||||
RECORD_ICON_PATH = get_asset_path('../assets/record-icon.svg')
|
||||
EXPAND_ICON_PATH = get_asset_path(
|
||||
'../assets/up-down-and-down-left-from-center-icon.svg')
|
||||
ADD_ICON_PATH = get_asset_path('../assets/circle-plus-icon.svg')
|
||||
|
||||
|
||||
class AboutDialog(QDialog):
|
||||
def __init__(self, parent: Optional[QWidget] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.setFixedSize(200, 200)
|
||||
self.setFixedSize(200, 250)
|
||||
|
||||
self.setWindowIcon(AppIcon())
|
||||
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
|
||||
self.setWindowTitle(f'About {APP_NAME}')
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
image_label = QLabel()
|
||||
pixmap = QPixmap(get_asset_path(ICON_LARGE_PATH)).scaled(
|
||||
pixmap = QPixmap(BUZZ_LARGE_ICON_PATH).scaled(
|
||||
80, 80, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
|
||||
image_label.setPixmap(pixmap)
|
||||
image_label.setAlignment(Qt.AlignmentFlag(
|
||||
|
@ -761,80 +707,240 @@ class AboutDialog(QDialog):
|
|||
version_label.setAlignment(Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
|
||||
|
||||
check_updates_button = QPushButton('Check for updates')
|
||||
check_updates_button = QPushButton('Check for updates', self)
|
||||
check_updates_button.clicked.connect(self.on_click_check_for_updates)
|
||||
|
||||
layout.addStretch(1)
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
QDialogButtonBox.StandardButton.Close), self)
|
||||
button_box.accepted.connect(self.accept)
|
||||
button_box.rejected.connect(self.reject)
|
||||
|
||||
layout.addWidget(image_label)
|
||||
layout.addWidget(buzz_label)
|
||||
layout.addWidget(version_label)
|
||||
layout.addWidget(check_updates_button)
|
||||
layout.addStretch(1)
|
||||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def on_click_check_for_updates(self):
|
||||
response = get(
|
||||
'https://api.github.com/repos/chidiwilliams/buzz/releases/latest', timeout=15).json()
|
||||
version_number = response.get('name')
|
||||
version_number = response.field('name')
|
||||
if version_number == 'v' + VERSION:
|
||||
dialog = QMessageBox(self)
|
||||
dialog.setText("You're up to date!")
|
||||
dialog.exec()
|
||||
dialog.open()
|
||||
else:
|
||||
QDesktopServices.openUrl(
|
||||
QUrl('https://github.com/chidiwilliams/buzz/releases/latest'))
|
||||
|
||||
|
||||
class TranscriptionTasksTableWidget(QTableWidget):
|
||||
TASK_ID_COLUMN_INDEX = 0
|
||||
FILE_NAME_COLUMN_INDEX = 1
|
||||
STATUS_COLUMN_INDEX = 2
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.setRowCount(0)
|
||||
self.setAlternatingRowColors(True)
|
||||
|
||||
self.setColumnCount(3)
|
||||
self.setColumnHidden(0, True)
|
||||
|
||||
self.verticalHeader().hide()
|
||||
self.setHorizontalHeaderLabels(['ID', 'File Name', 'Status'])
|
||||
self.horizontalHeader().setMinimumSectionSize(140)
|
||||
self.horizontalHeader().setSectionResizeMode(self.FILE_NAME_COLUMN_INDEX,
|
||||
QHeaderView.ResizeMode.Stretch)
|
||||
|
||||
self.setSelectionBehavior(
|
||||
QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
self.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
|
||||
|
||||
def upsert_task(self, task: FileTranscriptionTask):
|
||||
task_row_index = self.task_row_index(task.id)
|
||||
if task_row_index is None:
|
||||
self.insertRow(self.rowCount())
|
||||
|
||||
row_index = self.rowCount() - 1
|
||||
task_id_widget_item = QTableWidgetItem(str(task.id))
|
||||
self.setItem(row_index, self.TASK_ID_COLUMN_INDEX,
|
||||
task_id_widget_item)
|
||||
|
||||
file_name_widget_item = QTableWidgetItem(os.path.basename(
|
||||
task.file_transcription_options.file_path))
|
||||
file_name_widget_item.setFlags(
|
||||
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
|
||||
self.setItem(row_index, self.FILE_NAME_COLUMN_INDEX,
|
||||
file_name_widget_item)
|
||||
|
||||
status_widget_item = QTableWidgetItem(
|
||||
task.status.value.title() if task.status is not None else '')
|
||||
status_widget_item.setFlags(
|
||||
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
|
||||
self.setItem(row_index, self.STATUS_COLUMN_INDEX,
|
||||
status_widget_item)
|
||||
else:
|
||||
status_widget = self.item(task_row_index, self.STATUS_COLUMN_INDEX)
|
||||
|
||||
if task.status == FileTranscriptionTask.Status.IN_PROGRESS:
|
||||
status_widget.setText(
|
||||
f'In Progress ({task.fraction_completed :.0%})')
|
||||
elif task.status == FileTranscriptionTask.Status.COMPLETED:
|
||||
status_widget.setText('Completed')
|
||||
elif task.status == FileTranscriptionTask.Status.ERROR:
|
||||
status_widget.setText('Failed')
|
||||
|
||||
def task_row_index(self, task_id: int) -> int | None:
|
||||
table_items_matching_task_id = [item for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly) if
|
||||
item.column() == self.TASK_ID_COLUMN_INDEX]
|
||||
if len(table_items_matching_task_id) == 0:
|
||||
return None
|
||||
return table_items_matching_task_id[0].row()
|
||||
|
||||
@staticmethod
|
||||
def find_task_id(index: QModelIndex):
|
||||
return int(index.siblingAtColumn(TranscriptionTasksTableWidget.TASK_ID_COLUMN_INDEX).data())
|
||||
|
||||
|
||||
class MainWindow(QMainWindow):
|
||||
new_import_window_triggered = pyqtSignal(tuple)
|
||||
table_widget: TranscriptionTasksTableWidget
|
||||
next_task_id = 0
|
||||
tasks: Dict[int, 'FileTranscriptionTask']
|
||||
|
||||
def __init__(self, title: str, w: int, h: int, parent: Optional[QWidget], *args):
|
||||
super().__init__(parent, *args)
|
||||
def __init__(self):
|
||||
super().__init__(flags=Qt.WindowType.Window)
|
||||
|
||||
self.setFixedSize(w, h)
|
||||
self.setWindowTitle(f'{title} - {APP_NAME}')
|
||||
self.setWindowIcon(AppIcon())
|
||||
self.setWindowTitle(APP_NAME)
|
||||
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
|
||||
self.setFixedSize(400, 400)
|
||||
|
||||
import_audio_file_action = QAction("&Import Audio File...", self)
|
||||
import_audio_file_action.triggered.connect(
|
||||
self.on_import_audio_file_action)
|
||||
import_audio_file_action.setShortcut(QKeySequence.fromString('Ctrl+O'))
|
||||
self.tasks = {}
|
||||
|
||||
menu = self.menuBar()
|
||||
record_action = QAction(QIcon(RECORD_ICON_PATH), 'Record', self)
|
||||
record_action.triggered.connect(self.on_record_action_triggered)
|
||||
|
||||
self.file_menu = menu.addMenu("&File")
|
||||
self.file_menu.addAction(import_audio_file_action)
|
||||
new_transcription_action = QAction(
|
||||
QIcon(ADD_ICON_PATH), 'New Transcription', self)
|
||||
new_transcription_action.triggered.connect(
|
||||
self.on_new_transcription_action_triggered)
|
||||
|
||||
self.about_action = QAction(f'&About {APP_NAME}', self)
|
||||
self.about_action.triggered.connect(self.on_trigger_about_action)
|
||||
self.open_transcript_action = QAction(QIcon(EXPAND_ICON_PATH),
|
||||
'Open Transcript', self)
|
||||
self.open_transcript_action.triggered.connect(
|
||||
self.on_open_transcript_action_triggered)
|
||||
self.open_transcript_action.setDisabled(True)
|
||||
|
||||
self.help_menu = menu.addMenu("&Help")
|
||||
self.help_menu.addAction(self.about_action)
|
||||
toolbar = QToolBar()
|
||||
toolbar.addAction(record_action)
|
||||
toolbar.addSeparator()
|
||||
toolbar.addAction(new_transcription_action)
|
||||
toolbar.addAction(self.open_transcript_action)
|
||||
toolbar.setMovable(False)
|
||||
toolbar.setIconSize(QSize(16, 16))
|
||||
toolbar.setContentsMargins(0, 2, 0, 2)
|
||||
|
||||
def on_import_audio_file_action(self):
|
||||
# Fix spacing issue on Mac
|
||||
if platform.system() == 'Darwin':
|
||||
toolbar.widgetForAction(toolbar.actions()[0]).setStyleSheet(
|
||||
'QToolButton { margin-left: 9px; margin-right: 1px; }')
|
||||
|
||||
self.addToolBar(toolbar)
|
||||
self.setUnifiedTitleAndToolBarOnMac(True)
|
||||
|
||||
menu_bar = MenuBar(self)
|
||||
menu_bar.import_action_triggered.connect(
|
||||
self.on_new_transcription_action_triggered)
|
||||
self.setMenuBar(menu_bar)
|
||||
|
||||
self.table_widget = TranscriptionTasksTableWidget(self)
|
||||
self.table_widget.doubleClicked.connect(self.on_table_double_clicked)
|
||||
self.table_widget.itemSelectionChanged.connect(
|
||||
self.on_table_selection_changed)
|
||||
|
||||
self.setCentralWidget(self.table_widget)
|
||||
|
||||
# Start transcriber thread
|
||||
self.transcriber_thread = QThread()
|
||||
|
||||
self.transcriber_worker = FileTranscriberQueueWorker()
|
||||
self.transcriber_worker.moveToThread(self.transcriber_thread)
|
||||
|
||||
self.transcriber_worker.task_updated.connect(self.on_task_updated)
|
||||
self.transcriber_worker.completed.connect(self.transcriber_thread.quit)
|
||||
|
||||
self.transcriber_thread.started.connect(self.transcriber_worker.run)
|
||||
self.transcriber_thread.finished.connect(
|
||||
self.transcriber_thread.deleteLater)
|
||||
self.transcriber_thread.finished.connect(
|
||||
lambda: print('thread closed'))
|
||||
|
||||
self.transcriber_thread.start()
|
||||
|
||||
def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]):
|
||||
transcription_options, file_transcription_options, model_path = options
|
||||
task = FileTranscriptionTask(
|
||||
self.next_task_id, transcription_options, file_transcription_options, model_path)
|
||||
|
||||
self.transcriber_worker.add_task(task)
|
||||
|
||||
self.next_task_id += 1
|
||||
|
||||
def on_task_updated(self, task: FileTranscriptionTask):
|
||||
self.table_widget.upsert_task(task)
|
||||
self.tasks[task.id] = task
|
||||
|
||||
def on_record_action_triggered(self):
|
||||
recording_transcriber_window = RecordingTranscriberWidget(
|
||||
self, flags=Qt.WindowType.Window)
|
||||
recording_transcriber_window.show()
|
||||
|
||||
def on_new_transcription_action_triggered(self):
|
||||
(file_path, _) = QFileDialog.getOpenFileName(
|
||||
self, 'Select audio file', '', SUPPORTED_OUTPUT_FORMATS)
|
||||
if file_path == '':
|
||||
return
|
||||
self.new_import_window_triggered.emit((file_path, self.geometry()))
|
||||
|
||||
def on_trigger_about_action(self):
|
||||
about_dialog = AboutDialog(self)
|
||||
about_dialog.exec()
|
||||
file_transcriber_window = FileTranscriberWidget(
|
||||
file_path, self, flags=Qt.WindowType.Window)
|
||||
file_transcriber_window.triggered.connect(
|
||||
self.on_file_transcriber_triggered)
|
||||
file_transcriber_window.show()
|
||||
|
||||
def on_open_transcript_action_triggered(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
if len(selected_rows) == 0:
|
||||
return
|
||||
task_id = TranscriptionTasksTableWidget.find_task_id(selected_rows[0])
|
||||
self.open_transcription_viewer(task_id)
|
||||
|
||||
class RecordingTranscriberMainWindow(MainWindow):
|
||||
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(title='Live Recording', w=400, h=520, parent=parent, *args)
|
||||
def on_table_selection_changed(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
self.open_transcript_action.setDisabled(len(selected_rows) == 0)
|
||||
|
||||
self.central_widget = RecordingTranscriberWidget(self)
|
||||
self.central_widget.setContentsMargins(10, 10, 10, 10)
|
||||
self.setCentralWidget(self.central_widget)
|
||||
def on_table_double_clicked(self, index: QModelIndex):
|
||||
task_id = TranscriptionTasksTableWidget.find_task_id(index)
|
||||
self.open_transcription_viewer(task_id)
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
self.central_widget.stop_recording()
|
||||
return super().closeEvent(event)
|
||||
def open_transcription_viewer(self, task_id: int):
|
||||
task = self.tasks[task_id]
|
||||
if task.segments is None:
|
||||
return
|
||||
|
||||
transcription_viewer_widget = TranscriptionViewerWidget(
|
||||
file_transcription_options=task.file_transcription_options,
|
||||
transcription_options=task.transcription_options, segments=task.segments,
|
||||
parent=self, flags=Qt.WindowType.Window)
|
||||
transcription_viewer_widget.show()
|
||||
|
||||
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
|
||||
self.transcriber_worker.stop()
|
||||
self.transcriber_thread.quit()
|
||||
self.transcriber_thread.wait()
|
||||
super().closeEvent(event)
|
||||
|
||||
|
||||
class TranscriptionOptionsGroupBox(QGroupBox):
|
||||
|
@ -855,7 +961,8 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.languages_combo_box = LanguagesComboBox(
|
||||
default_language=self.transcription_options.language,
|
||||
parent=self)
|
||||
self.languages_combo_box.languageChanged.connect(self.on_language_changed)
|
||||
self.languages_combo_box.languageChanged.connect(
|
||||
self.on_language_changed)
|
||||
|
||||
self.model_combo_box = ModelComboBox(
|
||||
default_model=self.transcription_options.model,
|
||||
|
@ -894,8 +1001,10 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def open_advanced_settings(self):
|
||||
dialog = AdvancedSettingsDialog(transcription_options=self.transcription_options, parent=self)
|
||||
dialog.transcription_options_changed.connect(self.on_transcription_options_changed)
|
||||
dialog = AdvancedSettingsDialog(
|
||||
transcription_options=self.transcription_options, parent=self)
|
||||
dialog.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
dialog.exec()
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
|
@ -903,48 +1012,42 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.transcription_options_changed.emit(transcription_options)
|
||||
|
||||
|
||||
class FileTranscriberMainWindow(MainWindow):
|
||||
central_widget: FileTranscriberWidget
|
||||
class MenuBar(QMenuBar):
|
||||
import_action_triggered = pyqtSignal()
|
||||
|
||||
def __init__(self, file_path: str, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(title=get_short_file_path(
|
||||
file_path), w=400, h=270, parent=parent, *args)
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(parent)
|
||||
|
||||
self.central_widget = FileTranscriberWidget(file_path, self)
|
||||
self.central_widget.setContentsMargins(10, 10, 10, 10)
|
||||
self.setCentralWidget(self.central_widget)
|
||||
import_action = QAction("&Import Media File...", self)
|
||||
import_action.triggered.connect(
|
||||
self.on_import_action_triggered)
|
||||
import_action.setShortcut(QKeySequence.fromString('Ctrl+O'))
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
self.central_widget.on_cancel_transcriber_progress_dialog()
|
||||
return super().closeEvent(event)
|
||||
about_action = QAction(f'&About {APP_NAME}', self)
|
||||
about_action.triggered.connect(self.on_about_action_triggered)
|
||||
|
||||
file_menu = self.addMenu("&File")
|
||||
file_menu.addAction(import_action)
|
||||
|
||||
help_menu = self.addMenu("&Help")
|
||||
help_menu.addAction(about_action)
|
||||
|
||||
def on_import_action_triggered(self):
|
||||
self.import_action_triggered.emit()
|
||||
|
||||
def on_about_action_triggered(self):
|
||||
about_dialog = AboutDialog(self)
|
||||
about_dialog.open()
|
||||
|
||||
|
||||
class Application(QApplication):
|
||||
windows: List[MainWindow] = []
|
||||
window: MainWindow
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(sys.argv)
|
||||
|
||||
window = RecordingTranscriberMainWindow(None)
|
||||
window.new_import_window_triggered.connect(self.open_import_window)
|
||||
window.show()
|
||||
|
||||
self.windows.append(window)
|
||||
|
||||
def open_import_window(self, window_config: Tuple[str, QRect]):
|
||||
(file_path, geometry) = window_config
|
||||
|
||||
window = FileTranscriberMainWindow(file_path, None)
|
||||
|
||||
# Set window to open at an offset from the calling sibling
|
||||
OFFSET = 35
|
||||
geometry = QRect(geometry.left() + OFFSET, geometry.top() + OFFSET,
|
||||
geometry.width(), geometry.height())
|
||||
window.setGeometry(geometry)
|
||||
self.windows.append(window)
|
||||
|
||||
window.new_import_window_triggered.connect(self.open_import_window)
|
||||
window.show()
|
||||
self.window = MainWindow()
|
||||
self.window.show()
|
||||
|
||||
|
||||
class AdvancedSettingsDialog(QDialog):
|
||||
|
@ -973,12 +1076,15 @@ class AdvancedSettingsDialog(QDialog):
|
|||
self.temperature_line_edit.textChanged.connect(
|
||||
self.on_temperature_changed)
|
||||
self.temperature_line_edit.setValidator(TemperatureValidator(self))
|
||||
self.temperature_line_edit.setDisabled(transcription_options.model.is_whisper_cpp())
|
||||
self.temperature_line_edit.setDisabled(
|
||||
transcription_options.model.is_whisper_cpp())
|
||||
|
||||
self.initial_prompt_text_edit = QPlainTextEdit(transcription_options.initial_prompt, self)
|
||||
self.initial_prompt_text_edit = QPlainTextEdit(
|
||||
transcription_options.initial_prompt, self)
|
||||
self.initial_prompt_text_edit.textChanged.connect(
|
||||
self.on_initial_prompt_changed)
|
||||
self.initial_prompt_text_edit.setDisabled(transcription_options.model.is_whisper_cpp())
|
||||
self.initial_prompt_text_edit.setDisabled(
|
||||
transcription_options.model.is_whisper_cpp())
|
||||
|
||||
layout.addRow('Temperature:', self.temperature_line_edit)
|
||||
layout.addRow('Initial Prompt:', self.initial_prompt_text_edit)
|
||||
|
|
|
@ -6,10 +6,10 @@ from typing import Optional
|
|||
|
||||
import requests
|
||||
import whisper
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
|
||||
from platformdirs import user_cache_dir
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot, QThread
|
||||
|
||||
from buzz.transcriber import TranscriptionOptions, Model
|
||||
from buzz.transcriber import Model
|
||||
|
||||
MODELS_SHA256 = {
|
||||
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
|
||||
|
|
|
@ -6,7 +6,6 @@ import logging
|
|||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -22,7 +21,7 @@ import numpy as np
|
|||
import sounddevice
|
||||
import stable_whisper
|
||||
import whisper
|
||||
from PyQt6.QtCore import QObject, QProcess, pyqtSignal, pyqtSlot
|
||||
from PyQt6.QtCore import QObject, QProcess, pyqtSignal, pyqtSlot, QThread
|
||||
from sounddevice import PortAudioError
|
||||
|
||||
from .conn import pipe_stderr
|
||||
|
@ -243,11 +242,13 @@ class WhisperCppFileTranscriber(QObject):
|
|||
running = False
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions,
|
||||
file_transcription_options: FileTranscriptionOptions, parent: Optional['QObject'] = None) -> None:
|
||||
file_transcription_options: FileTranscriptionOptions, model_path: str,
|
||||
parent: Optional['QObject'] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.file_path = file_transcription_options.file_path
|
||||
self.language = transcription_options.language
|
||||
self.model_path = model_path
|
||||
self.task = transcription_options.task
|
||||
self.word_level_timings = transcription_options.word_level_timings
|
||||
self.segments = []
|
||||
|
@ -256,9 +257,10 @@ class WhisperCppFileTranscriber(QObject):
|
|||
self.process.readyReadStandardError.connect(self.read_std_err)
|
||||
self.process.readyReadStandardOutput.connect(self.read_std_out)
|
||||
|
||||
@pyqtSlot(str)
|
||||
def run(self, model_path: str):
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
self.running = True
|
||||
model_path = self.model_path
|
||||
|
||||
logging.debug(
|
||||
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, '
|
||||
|
@ -358,7 +360,9 @@ class WhisperFileTranscriber(QObject):
|
|||
READ_LINE_THREAD_STOP_TOKEN = '--STOP--'
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions,
|
||||
file_transcription_options: FileTranscriptionOptions, parent: Optional['QObject'] = None) -> None:
|
||||
file_transcription_options: FileTranscriptionOptions,
|
||||
model_path: str,
|
||||
parent: Optional['QObject'] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.file_path = file_transcription_options.file_path
|
||||
|
@ -367,16 +371,19 @@ class WhisperFileTranscriber(QObject):
|
|||
self.word_level_timings = transcription_options.word_level_timings
|
||||
self.temperature = transcription_options.temperature
|
||||
self.initial_prompt = transcription_options.initial_prompt
|
||||
self.model_path = model_path
|
||||
self.segments = []
|
||||
|
||||
@pyqtSlot(str)
|
||||
def run(self, model_path: str):
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
self.running = True
|
||||
model_path = self.model_path
|
||||
time_started = datetime.datetime.now()
|
||||
logging.debug(
|
||||
'Starting whisper file transcription, file path = %s, language = %s, task = %s, model path = %s, '
|
||||
'temperature = %s, initial prompt length = %s, word level timings = %s',
|
||||
self.file_path, self.language, self.task, model_path, self.temperature, len(self.initial_prompt),
|
||||
self.file_path, self.language, self.task, model_path, self.temperature, len(
|
||||
self.initial_prompt),
|
||||
self.word_level_timings)
|
||||
|
||||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
|
@ -390,11 +397,15 @@ class WhisperFileTranscriber(QObject):
|
|||
))
|
||||
self.current_process.start()
|
||||
|
||||
self.read_line_thread = Thread(target=self.read_line, args=(recv_pipe,))
|
||||
self.read_line_thread = Thread(
|
||||
target=self.read_line, args=(recv_pipe,))
|
||||
self.read_line_thread.start()
|
||||
|
||||
self.current_process.join()
|
||||
|
||||
send_pipe.close()
|
||||
recv_pipe.close()
|
||||
|
||||
logging.debug(
|
||||
'whisper process completed with code = %s, time taken = %s',
|
||||
self.current_process.exitcode, datetime.datetime.now() - time_started)
|
||||
|
@ -412,7 +423,11 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
def read_line(self, pipe: Connection):
|
||||
while True:
|
||||
line = pipe.recv().strip()
|
||||
try:
|
||||
line = pipe.recv().strip()
|
||||
except EOFError: # Connection closed
|
||||
break
|
||||
|
||||
if line == self.READ_LINE_THREAD_STOP_TOKEN:
|
||||
return
|
||||
|
||||
|
@ -464,7 +479,8 @@ def transcribe_whisper(
|
|||
segments_json = json.dumps(
|
||||
segments, ensure_ascii=True, default=vars)
|
||||
sys.stderr.write(f'segments = {segments_json}\n')
|
||||
sys.stderr.write(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + '\n')
|
||||
sys.stderr.write(
|
||||
WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + '\n')
|
||||
|
||||
|
||||
def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat):
|
||||
|
@ -581,3 +597,117 @@ class WhisperCpp:
|
|||
|
||||
def __del__(self):
|
||||
whisper_cpp.whisper_free((self.ctx))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileTranscriptionTask:
|
||||
class Status(enum.Enum):
|
||||
QUEUED = 'queued'
|
||||
IN_PROGRESS = 'in_progress'
|
||||
COMPLETED = 'completed'
|
||||
ERROR = 'error'
|
||||
|
||||
id: int
|
||||
transcription_options: TranscriptionOptions
|
||||
file_transcription_options: FileTranscriptionOptions
|
||||
model_path: str
|
||||
segments: Optional[List[Segment]] = None
|
||||
status: Optional[Status] = None
|
||||
fraction_completed = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class FileTranscriberQueueWorker(QObject):
|
||||
queue: multiprocessing.Queue
|
||||
current_task: Optional[FileTranscriptionTask] = None
|
||||
current_transcriber: Optional[WhisperFileTranscriber |
|
||||
WhisperCppFileTranscriber] = None
|
||||
current_transcriber_thread: Optional[QThread] = None
|
||||
task_updated = pyqtSignal(FileTranscriptionTask)
|
||||
completed = pyqtSignal()
|
||||
|
||||
QUEUE_STOP_SIGNAL = None
|
||||
|
||||
def __init__(self, parent: Optional[QObject] = None):
|
||||
super().__init__(parent)
|
||||
self.queue = multiprocessing.Queue()
|
||||
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
logging.debug('Waiting for next file transcription task')
|
||||
self.current_task = self.queue.get()
|
||||
if self.current_task is self.QUEUE_STOP_SIGNAL:
|
||||
self.completed.emit()
|
||||
return
|
||||
|
||||
if self.current_task.transcription_options.model.is_whisper_cpp():
|
||||
self.current_transcriber = WhisperCppFileTranscriber(
|
||||
transcription_options=self.current_task.transcription_options,
|
||||
file_transcription_options=self.current_task.file_transcription_options,
|
||||
model_path=self.current_task.model_path,
|
||||
)
|
||||
else:
|
||||
self.current_transcriber = WhisperFileTranscriber(
|
||||
transcription_options=self.current_task.transcription_options,
|
||||
file_transcription_options=self.current_task.file_transcription_options,
|
||||
model_path=self.current_task.model_path,
|
||||
)
|
||||
|
||||
self.current_transcriber_thread = QThread(self)
|
||||
|
||||
self.current_transcriber.moveToThread(self.current_transcriber_thread)
|
||||
|
||||
self.current_transcriber_thread.started.connect(
|
||||
self.current_transcriber.run)
|
||||
self.current_transcriber.completed.connect(
|
||||
self.current_transcriber_thread.quit)
|
||||
|
||||
self.current_transcriber.completed.connect(
|
||||
self.current_transcriber.deleteLater)
|
||||
self.current_transcriber_thread.finished.connect(
|
||||
self.current_transcriber_thread.deleteLater)
|
||||
|
||||
self.current_transcriber.progress.connect(self.on_task_progress)
|
||||
self.current_transcriber.error.connect(self.on_task_error)
|
||||
|
||||
self.current_transcriber.completed.connect(self.on_task_completed)
|
||||
|
||||
# Wait for next item on the queue
|
||||
self.current_transcriber.completed.connect(self.run)
|
||||
|
||||
self.current_transcriber_thread.start()
|
||||
|
||||
def add_task(self, task: FileTranscriptionTask):
|
||||
self.queue.put(task)
|
||||
task.status = FileTranscriptionTask.Status.QUEUED
|
||||
self.task_updated.emit(task)
|
||||
|
||||
@pyqtSlot(str)
|
||||
def on_task_error(self, error: str):
|
||||
if self.current_task is not None:
|
||||
self.current_task.status = FileTranscriptionTask.Status.ERROR
|
||||
self.current_task.error = error
|
||||
self.task_updated.emit(self.current_task)
|
||||
|
||||
@pyqtSlot(tuple)
|
||||
def on_task_progress(self, progress: Tuple[int, int]):
|
||||
if self.current_task is not None:
|
||||
self.current_task.status = FileTranscriptionTask.Status.IN_PROGRESS
|
||||
self.current_task.fraction_completed = progress[0] / progress[1]
|
||||
self.task_updated.emit(self.current_task)
|
||||
|
||||
@pyqtSlot(tuple)
|
||||
def on_task_completed(self, result: Tuple[int, List[Segment]]):
|
||||
if self.current_task is not None:
|
||||
_, segments = result
|
||||
self.current_task.status = FileTranscriptionTask.Status.COMPLETED
|
||||
self.current_task.segments = segments
|
||||
self.task_updated.emit(self.current_task)
|
||||
|
||||
def stop(self):
|
||||
self.queue.put(self.QUEUE_STOP_SIGNAL)
|
||||
if self.current_transcriber is not None:
|
||||
self.current_transcriber.stop()
|
||||
if self.current_transcriber_thread is not None:
|
||||
self.current_transcriber_thread.quit()
|
||||
self.current_transcriber_thread.wait()
|
||||
|
|
|
@ -13,31 +13,17 @@ from pytestqt.qtbot import QtBot
|
|||
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
|
||||
AudioDevicesComboBox, DownloadModelProgressDialog,
|
||||
FileTranscriberWidget, LanguagesComboBox, MainWindow,
|
||||
ModelComboBox, TemperatureValidator,
|
||||
TextDisplayBox, TranscriberProgressDialog, TranscriptionViewerWidget, AppIcon)
|
||||
from buzz.transcriber import FileTranscriptionOptions, Segment, TranscriptionOptions, Model
|
||||
ModelComboBox, RecordingTranscriberWidget, TemperatureValidator,
|
||||
TextDisplayBox, TranscriberProgressDialog, TranscriptionTasksTableWidget, TranscriptionViewerWidget,)
|
||||
from buzz.transcriber import FileTranscriptionOptions, FileTranscriptionTask, Segment, Task, TranscriptionOptions, Model
|
||||
|
||||
|
||||
class TestApplication:
|
||||
# FIXME: this seems to break the tests if not run??
|
||||
app = Application()
|
||||
|
||||
def test_should_show_window_title(self):
|
||||
assert len(self.app.windows) == 1
|
||||
assert self.app.windows[0].windowTitle() == 'Live Recording - Buzz'
|
||||
|
||||
def test_should_open_a_new_import_file_window(self):
|
||||
main_window = self.app.windows[0]
|
||||
import_file_action = main_window.file_menu.actions()[0]
|
||||
|
||||
assert import_file_action.text() == '&Import Audio File...'
|
||||
|
||||
with patch('PyQt6.QtWidgets.QFileDialog.getOpenFileName') as open_file_name_mock:
|
||||
open_file_name_mock.return_value = ('/a/b/c.mp3', '')
|
||||
import_file_action.trigger()
|
||||
assert len(self.app.windows) == 2
|
||||
|
||||
new_window = self.app.windows[1]
|
||||
assert new_window.windowTitle() == 'c.mp3 - Buzz'
|
||||
def test_should_open_application(self):
|
||||
assert self.app is not None
|
||||
|
||||
|
||||
class TestLanguagesComboBox:
|
||||
|
@ -182,9 +168,12 @@ class TestDownloadModelProgressDialog:
|
|||
|
||||
|
||||
class TestMainWindow:
|
||||
def test_should_init(self):
|
||||
main_window = MainWindow(title='', w=200, h=200, parent=None)
|
||||
assert main_window is not None
|
||||
window = MainWindow()
|
||||
|
||||
def test_should_set_window_title_and_icon(self, qtbot: QtBot):
|
||||
qtbot.add_widget(self.window)
|
||||
assert self.window.windowTitle() == 'Buzz'
|
||||
assert self.window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
|
||||
|
||||
|
||||
def wait_until(callback: Callable[[], Any], timeout=0):
|
||||
|
@ -198,21 +187,31 @@ def wait_until(callback: Callable[[], Any], timeout=0):
|
|||
|
||||
|
||||
class TestFileTranscriberWidget:
|
||||
@pytest.mark.skip(reason='Waiting for signal crashes process on Windows and Mac')
|
||||
def test_should_transcribe(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_path='testdata/whisper-french.mp3', parent=None)
|
||||
|
||||
def test_should_set_window_title_and_size(self, qtbot: QtBot):
|
||||
qtbot.addWidget(self.widget)
|
||||
assert self.widget.windowTitle() == 'whisper-french.mp3'
|
||||
assert self.widget.size() == QSize(420, 270)
|
||||
|
||||
def test_should_emit_triggered_event(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_path='testdata/whisper-french.mp3', parent=None)
|
||||
qtbot.addWidget(widget)
|
||||
|
||||
# Waiting for a "transcribed" signal seems to work more consistently
|
||||
# than checking for the opening of a TranscriptionViewerWidget.
|
||||
# See also: https://github.com/pytest-dev/pytest-qt/issues/313
|
||||
with qtbot.wait_signal(widget.transcribed, timeout=30 * 1000):
|
||||
mock_triggered = Mock()
|
||||
widget.triggered.connect(mock_triggered)
|
||||
|
||||
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
|
||||
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
|
||||
|
||||
transcription_viewer = widget.findChild(TranscriptionViewerWidget)
|
||||
assert isinstance(transcription_viewer, TranscriptionViewerWidget)
|
||||
assert len(transcription_viewer.segments) > 0
|
||||
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
|
||||
0][0]
|
||||
assert transcription_options.language is None
|
||||
assert transcription_options.model == Model.WHISPER_TINY
|
||||
assert file_transcription_options.file_path == 'testdata/whisper-french.mp3'
|
||||
assert len(model_path) > 0
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="transcription_started callback sometimes not getting called until all progress events are emitted")
|
||||
|
@ -254,7 +253,8 @@ class TestAdvancedSettingsDialog:
|
|||
qtbot.add_widget(dialog)
|
||||
|
||||
transcription_options_mock = Mock()
|
||||
dialog.transcription_options_changed.connect(transcription_options_mock)
|
||||
dialog.transcription_options_changed.connect(
|
||||
transcription_options_mock)
|
||||
|
||||
assert dialog.windowTitle() == 'Advanced Settings'
|
||||
assert dialog.temperature_line_edit.text() == '0.0, 0.8'
|
||||
|
@ -263,7 +263,8 @@ class TestAdvancedSettingsDialog:
|
|||
dialog.temperature_line_edit.setText('0.0, 0.8, 1.0')
|
||||
dialog.initial_prompt_text_edit.setPlainText('new prompt')
|
||||
|
||||
assert transcription_options_mock.call_args[0][0].temperature == (0.0, 0.8, 1.0)
|
||||
assert transcription_options_mock.call_args[0][0].temperature == (
|
||||
0.0, 0.8, 1.0)
|
||||
assert transcription_options_mock.call_args[0][0].initial_prompt == 'new prompt'
|
||||
|
||||
|
||||
|
@ -313,7 +314,34 @@ class TestTranscriptionViewerWidget:
|
|||
assert 'Bien venue dans' in output_file.read()
|
||||
|
||||
|
||||
class TestAppIcon:
|
||||
def test_loads(self):
|
||||
widget = AppIcon()
|
||||
assert widget.pixmap(QSize(64, 64)).isNull() is False
|
||||
class TestTranscriptionTasksTableWidget:
|
||||
widget = TranscriptionTasksTableWidget()
|
||||
|
||||
def test_upsert_task(self, qtbot: QtBot):
|
||||
qtbot.add_widget(self.widget)
|
||||
|
||||
task = FileTranscriptionTask(id=0, transcription_options=TranscriptionOptions(
|
||||
), file_transcription_options=FileTranscriptionOptions(file_path='testdata/whisper-french.mp3'), model_path='', status=FileTranscriptionTask.Status.QUEUED)
|
||||
|
||||
self.widget.upsert_task(task)
|
||||
|
||||
assert self.widget.rowCount() == 1
|
||||
assert self.widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert self.widget.item(0, 2).text() == 'Queued'
|
||||
|
||||
task.status = FileTranscriptionTask.Status.IN_PROGRESS
|
||||
task.fraction_completed = 0.3524
|
||||
self.widget.upsert_task(task)
|
||||
|
||||
assert self.widget.rowCount() == 1
|
||||
assert self.widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert self.widget.item(0, 2).text() == 'In Progress (35%)'
|
||||
|
||||
|
||||
class TestRecordingTranscriberWidget:
|
||||
widget = RecordingTranscriberWidget()
|
||||
|
||||
def test_should_set_window_title_and_size(self, qtbot: QtBot):
|
||||
qtbot.add_widget(self.widget)
|
||||
assert self.widget.windowTitle() == 'Live Recording'
|
||||
assert self.widget.size() == QSize(400, 520)
|
||||
|
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.model_loader import ModelLoader
|
||||
from buzz.transcriber import (FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Segment, Task,
|
||||
from buzz.transcriber import (FileTranscriberQueueWorker, FileTranscriptionOptions, OutputFormat, RecordingTranscriber, Segment, Task,
|
||||
WhisperCpp, WhisperCppFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
get_default_output_file_path, to_timestamp,
|
||||
|
@ -46,19 +46,20 @@ class TestWhisperCppFileTranscriber:
|
|||
(True, [Segment(30, 280, 'Bien'), Segment(280, 630, 'venue')])
|
||||
])
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||
file_transcription_options = FileTranscriptionOptions(file_path='testdata/whisper-french.mp3')
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_path='testdata/whisper-french.mp3')
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings)
|
||||
|
||||
model_path = get_model_path(Model.WHISPER_CPP_TINY)
|
||||
transcriber = WhisperCppFileTranscriber(
|
||||
file_transcription_options=file_transcription_options, transcription_options=transcription_options)
|
||||
file_transcription_options=file_transcription_options, transcription_options=transcription_options, model_path=model_path)
|
||||
mock_progress = Mock()
|
||||
mock_completed = Mock()
|
||||
transcriber.progress.connect(mock_progress)
|
||||
transcriber.completed.connect(mock_completed)
|
||||
with qtbot.waitSignal(transcriber.completed, timeout=10 * 60 * 1000):
|
||||
transcriber.run(model_path)
|
||||
transcriber.run()
|
||||
|
||||
mock_progress.assert_called()
|
||||
exit_code, segments = mock_completed.call_args[0][0]
|
||||
|
@ -96,14 +97,15 @@ class TestWhisperFileTranscriber:
|
|||
mock_completed = Mock()
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings)
|
||||
file_transcription_options = FileTranscriptionOptions(file_path='testdata/whisper-french.mp3')
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_path='testdata/whisper-french.mp3')
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
file_transcription_options=file_transcription_options, transcription_options=transcription_options)
|
||||
file_transcription_options=file_transcription_options, transcription_options=transcription_options, model_path=model_path)
|
||||
transcriber.progress.connect(mock_progress)
|
||||
transcriber.completed.connect(mock_completed)
|
||||
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
|
||||
transcriber.run(model_path)
|
||||
transcriber.run()
|
||||
|
||||
# Reports progress at 0, 0<progress<100, and 100
|
||||
assert any(
|
||||
|
@ -127,12 +129,14 @@ class TestWhisperFileTranscriber:
|
|||
os.remove(output_file_path)
|
||||
|
||||
model_path = get_model_path(Model.WHISPER_TINY)
|
||||
file_transcription_options = FileTranscriptionOptions(file_path='testdata/whisper-french.mp3')
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_path='testdata/whisper-french.mp3')
|
||||
transcription_options = TranscriptionOptions(
|
||||
language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
file_transcription_options=file_transcription_options, transcription_options=transcription_options)
|
||||
transcriber.run(model_path)
|
||||
file_transcription_options=file_transcription_options, transcription_options=transcription_options, model_path=model_path)
|
||||
transcriber.run()
|
||||
time.sleep(1)
|
||||
transcriber.stop()
|
||||
|
||||
|
@ -164,8 +168,8 @@ class TestWhisperCpp:
|
|||
[
|
||||
(OutputFormat.TXT, 'Bien venue dans\n'),
|
||||
(
|
||||
OutputFormat.SRT,
|
||||
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
OutputFormat.SRT,
|
||||
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
(OutputFormat.VTT,
|
||||
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
])
|
||||
|
|
Loading…
Reference in a new issue