Cache task queue on app exit (#260)

This commit is contained in:
Chidi Williams 2022-12-21 14:19:11 +00:00 committed by GitHub
parent 4747a5f655
commit 711a1b95be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 219 additions and 69 deletions

View file

@ -4,6 +4,7 @@ Transcribe and translate audio offline on your personal computer. Powered by Ope
![MIT License](https://img.shields.io/badge/license-MIT-green)
[![CI](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml/badge.svg)](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml)
[![codecov](https://codecov.io/github/chidiwilliams/buzz/branch/main/graph/badge.svg?token=YJSB8S2VEP)](https://codecov.io/github/chidiwilliams/buzz)
![GitHub release (latest by date)](https://img.shields.io/github/v/release/chidiwilliams/buzz)
![Buzz](./assets/buzz-banner.jpg)

4
assets/trash-icon.svg Normal file
View file

@ -0,0 +1,4 @@
<svg xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 448 512"><!--! Font Awesome Pro 6.2.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2022 Fonticons, Inc. -->
<path fill="#888" d="M135.2 17.7L128 32H32C14.3 32 0 46.3 0 64S14.3 96 32 96H416c17.7 0 32-14.3 32-32s-14.3-32-32-32H320l-7.2-14.3C307.4 6.8 296.3 0 284.2 0H163.8c-12.1 0-23.2 6.8-28.6 17.7zM416 128H32L53.2 467c1.6 25.3 22.6 45 47.9 45H346.9c25.3 0 46.3-19.7 47.9-45L416 128z"/>
</svg>

After

Width:  |  Height:  |  Size: 525 B

31
buzz/cache.py Normal file
View file

@ -0,0 +1,31 @@
import logging
import os
import pickle
from typing import List
from platformdirs import user_cache_dir
from .transcriber import FileTranscriptionTask
class TasksCache:
    """Persist the list of file transcription tasks between app sessions.

    Tasks are pickled into a single file named ``tasks`` inside the cache
    directory.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file it
    reads, so the cache file must only ever be written by this application.
    """

    def __init__(self, cache_dir=None):
        """Create the cache directory if needed and remember the cache file path.

        Args:
            cache_dir: Directory that holds the cache file. When omitted, the
                platform user cache directory for 'Buzz' is used; it is
                resolved here rather than at import time.
        """
        if cache_dir is None:
            cache_dir = user_cache_dir('Buzz')
        os.makedirs(cache_dir, exist_ok=True)
        self.file_path = os.path.join(cache_dir, 'tasks')

    def save(self, tasks: "List[FileTranscriptionTask]"):
        """Overwrite the cache file with the pickled ``tasks`` list."""
        with open(self.file_path, 'wb') as file:
            pickle.dump(tasks, file)

    def load(self) -> "List[FileTranscriptionTask]":
        """Return the cached tasks, or an empty list when none can be read.

        A missing cache file yields an empty list. A cache file that cannot be
        unpickled (empty, truncated, or written by an incompatible version of
        the task classes) is deleted and also yields an empty list, so a bad
        cache never prevents the caller from starting up.
        """
        try:
            with open(self.file_path, 'rb') as file:
                return pickle.load(file)
        except FileNotFoundError:
            return []
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError):
            # These are the exceptions the pickle documentation lists for
            # unreadable data; treat all of them as a corrupted cache.
            logging.warning('Discarding unreadable tasks cache: %s',
                            self.file_path, exc_info=True)
            self.clear()
            return []

    def clear(self):
        """Delete the cache file; do nothing if it does not exist."""
        try:
            os.remove(self.file_path)
        except FileNotFoundError:
            pass

View file

@ -1,15 +1,15 @@
import enum
import logging
import os
import platform
import random
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import humanize
import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (QDateTime, QObject, QSettings, Qt, QThread, pyqtSlot,
from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QSize)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor, QValidator)
@ -22,10 +22,12 @@ from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
from requests import get
from whisper import tokenizer
from buzz.cache import TasksCache
from .__version__ import VERSION
from .model_loader import ModelLoader
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
RecordingTranscriber, Segment, Task,
RecordingTranscriber, Task,
WhisperCppFileTranscriber, WhisperFileTranscriber,
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
Model, FileTranscriberQueueWorker, FileTranscriptionTask)
@ -287,7 +289,7 @@ def show_model_download_error_dialog(parent: QWidget, error: str):
class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
file_transcriber: Optional[Union[WhisperFileTranscriber,
WhisperCppFileTranscriber]] = None
WhisperCppFileTranscriber]] = None
model_loader: Optional[ModelLoader] = None
transcriber_thread: Optional[QThread] = None
file_transcription_options: FileTranscriptionOptions
@ -402,7 +404,8 @@ class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask
def __init__(
self, transcription_task: FileTranscriptionTask, parent: Optional['QWidget'] = None, flags: Qt.WindowType = Qt.WindowType.Widget,
self, transcription_task: FileTranscriptionTask, parent: Optional['QWidget'] = None,
flags: Qt.WindowType = Qt.WindowType.Widget,
) -> None:
super().__init__(parent, flags)
self.transcription_task = transcription_task
@ -637,6 +640,7 @@ RECORD_ICON_PATH = get_asset_path('../assets/record-icon.svg')
EXPAND_ICON_PATH = get_asset_path(
'../assets/up-down-and-down-left-from-center-icon.svg')
ADD_ICON_PATH = get_asset_path('../assets/circle-plus-icon.svg')
TRASH_ICON_PATH = get_asset_path('../assets/trash-icon.svg')
class AboutDialog(QDialog):
@ -756,6 +760,10 @@ class TranscriptionTasksTableWidget(QTableWidget):
elif task.status == FileTranscriptionTask.Status.ERROR:
status_widget.setText('Failed')
def clear_task(self, task_id: int):
    """Remove the table row that displays the task with ``task_id``.

    Does nothing when no row matches, because ``task_row_index`` may return
    ``None`` and ``removeRow`` requires an integer row.
    """
    task_row_index = self.task_row_index(task_id)
    # Compare against None explicitly: row 0 is a valid index.
    if task_row_index is None:
        return
    self.removeRow(task_row_index)
def task_row_index(self, task_id: int) -> int | None:
table_items_matching_task_id = [item for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly) if
item.column() == self.TASK_ID_COLUMN_INDEX]
@ -768,49 +776,78 @@ class TranscriptionTasksTableWidget(QTableWidget):
return int(index.siblingAtColumn(TranscriptionTasksTableWidget.TASK_ID_COLUMN_INDEX).data())
class MainWindow(QMainWindow):
table_widget: TranscriptionTasksTableWidget
next_task_id = 0
tasks: Dict[int, 'FileTranscriptionTask']
class MainWindowToolbar(QToolBar):
new_transcription_action_triggered: pyqtSignal
open_transcript_action_triggered: pyqtSignal
clear_history_action_triggered: pyqtSignal
def __init__(self):
super().__init__(flags=Qt.WindowType.Window)
self.setWindowTitle(APP_NAME)
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
self.setFixedSize(400, 400)
self.tasks = {}
def __init__(self, parent: Optional[QWidget]):
super().__init__(parent)
record_action = QAction(QIcon(RECORD_ICON_PATH), 'Record', self)
record_action.triggered.connect(self.on_record_action_triggered)
new_transcription_action = QAction(
QIcon(ADD_ICON_PATH), 'New Transcription', self)
new_transcription_action.triggered.connect(
self.on_new_transcription_action_triggered)
self.new_transcription_action_triggered = new_transcription_action.triggered
self.open_transcript_action = QAction(QIcon(EXPAND_ICON_PATH),
'Open Transcript', self)
self.open_transcript_action.triggered.connect(
self.on_open_transcript_action_triggered)
self.open_transcript_action_triggered = self.open_transcript_action.triggered
self.open_transcript_action.setDisabled(True)
toolbar = QToolBar()
toolbar.addAction(record_action)
toolbar.addSeparator()
toolbar.addAction(new_transcription_action)
toolbar.addAction(self.open_transcript_action)
toolbar.setMovable(False)
toolbar.setIconSize(QSize(16, 16))
toolbar.setContentsMargins(0, 2, 0, 2)
self.clear_history_action = QAction(QIcon(TRASH_ICON_PATH), 'Clear History', self)
self.clear_history_action_triggered = self.clear_history_action.triggered
self.clear_history_action.setDisabled(True)
self.addAction(record_action)
self.addSeparator()
self.addAction(new_transcription_action)
self.addAction(self.open_transcript_action)
self.addAction(self.clear_history_action)
self.setMovable(False)
self.setIconSize(QSize(16, 16))
self.setContentsMargins(0, 2, 0, 2)
# Fix spacing issue on Mac
if platform.system() == 'Darwin':
toolbar.widgetForAction(toolbar.actions()[0]).setStyleSheet(
self.widgetForAction(self.actions()[0]).setStyleSheet(
'QToolButton { margin-left: 9px; margin-right: 1px; }')
self.addToolBar(toolbar)
def on_record_action_triggered(self):
recording_transcriber_window = RecordingTranscriberWidget(
self, flags=Qt.WindowType.Window)
recording_transcriber_window.show()
# Disable (True) or enable (False) the toolbar's 'Open Transcript' action.
def set_open_transcript_action_disabled(self, disabled: bool):
self.open_transcript_action.setDisabled(disabled)
# Enable (True) or disable (False) the toolbar's 'Clear History' action.
def set_clear_history_action_enabled(self, enabled: bool):
self.clear_history_action.setEnabled(enabled)
class MainWindow(QMainWindow):
table_widget: TranscriptionTasksTableWidget
tasks: Dict[int, 'FileTranscriptionTask']
tasks_changed = pyqtSignal()
def __init__(self, tasks_cache=TasksCache()):
super().__init__(flags=Qt.WindowType.Window)
self.setWindowTitle(APP_NAME)
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
self.setMinimumSize(400, 400)
self.tasks_cache = tasks_cache
self.tasks = {}
self.tasks_changed.connect(self.on_tasks_changed)
self.toolbar = MainWindowToolbar(self)
self.toolbar.new_transcription_action_triggered.connect(self.on_new_transcription_action_triggered)
self.toolbar.open_transcript_action_triggered.connect(self.on_open_transcript_action_triggered)
self.toolbar.clear_history_action_triggered.connect(self.on_clear_history_action_triggered)
self.addToolBar(self.toolbar)
self.setUnifiedTitleAndToolBarOnMac(True)
menu_bar = MenuBar(self)
@ -831,7 +868,8 @@ class MainWindow(QMainWindow):
self.transcriber_worker = FileTranscriberQueueWorker()
self.transcriber_worker.moveToThread(self.transcriber_thread)
self.transcriber_worker.task_updated.connect(self.on_task_updated)
self.transcriber_worker.task_updated.connect(
self.update_task_table_row)
self.transcriber_worker.completed.connect(self.transcriber_thread.quit)
self.transcriber_thread.started.connect(self.transcriber_worker.run)
@ -840,22 +878,35 @@ class MainWindow(QMainWindow):
self.transcriber_thread.start()
self.load_tasks_from_cache()
def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]):
transcription_options, file_transcription_options, model_path = options
for file_path in file_transcription_options.file_paths:
task = FileTranscriptionTask(
file_path, transcription_options, file_transcription_options, model_path, id=self.next_task_id)
file_path, transcription_options, file_transcription_options, model_path, id=self.get_next_task_id())
self.transcriber_worker.add_task(task)
self.next_task_id += 1
def on_task_updated(self, task: FileTranscriptionTask):
@classmethod
def get_next_task_id(cls) -> int:
    """Return a random task id between 0 and 1,000,000 inclusive.

    The id is drawn independently on every call; no uniqueness check is
    performed here.
    """
    largest_task_id = 1_000_000
    return random.randint(0, largest_task_id)
# Insert or update the table row for `task`, record the task in self.tasks
# under its id, then emit tasks_changed so listeners can react.
def update_task_table_row(self, task: FileTranscriptionTask):
self.table_widget.upsert_task(task)
self.tasks[task.id] = task
self.tasks_changed.emit()
def on_record_action_triggered(self):
recording_transcriber_window = RecordingTranscriberWidget(
self, flags=Qt.WindowType.Window)
recording_transcriber_window.show()
@staticmethod
def task_completed_or_errored(task: FileTranscriptionTask):
    """Return True when the task's status is COMPLETED or ERROR."""
    finished_statuses = (FileTranscriptionTask.Status.COMPLETED,
                         FileTranscriptionTask.Status.ERROR)
    return task.status in finished_statuses
# Handle the 'Clear History' action: drop every completed or errored task
# from both the table widget and self.tasks, and emit tasks_changed.
def on_clear_history_action_triggered(self):
# Iterate over a list copy because self.tasks is mutated inside the loop.
for task_id, task in list(self.tasks.items()):
if self.task_completed_or_errored(task):
self.table_widget.clear_task(task_id)
self.tasks.pop(task_id)
self.tasks_changed.emit()
def on_new_transcription_action_triggered(self):
(file_paths, _) = QFileDialog.getOpenFileNames(
@ -878,7 +929,7 @@ class MainWindow(QMainWindow):
def on_table_selection_changed(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
self.open_transcript_action.setDisabled(len(selected_rows) == 0)
self.toolbar.set_open_transcript_action_disabled(len(selected_rows) == 0)
def on_table_double_clicked(self, index: QModelIndex):
task_id = TranscriptionTasksTableWidget.find_task_id(index)
@ -893,10 +944,28 @@ class MainWindow(QMainWindow):
transcription_task=task, parent=self, flags=Qt.WindowType.Window)
transcription_viewer_widget.show()
# Restore tasks saved by a previous session from self.tasks_cache.
# Tasks that were still QUEUED or IN_PROGRESS get their status reset to None
# and are handed back to the transcriber worker; the remaining tasks are
# passed to update_task_table_row.
# NOTE(review): indentation is lost in this view; the `else` is read as
# belonging to the `if`, not the `for` - confirm against the real file.
def load_tasks_from_cache(self):
tasks = self.tasks_cache.load()
for task in tasks:
if task.status == FileTranscriptionTask.Status.QUEUED or \
task.status == FileTranscriptionTask.Status.IN_PROGRESS:
task.status = None
self.transcriber_worker.add_task(task)
else:
self.update_task_table_row(task)
def save_tasks_to_cache(self):
    """Write the current values of ``self.tasks`` to the tasks cache as a list."""
    current_tasks = list(self.tasks.values())
    self.tasks_cache.save(current_tasks)
def on_tasks_changed(self):
    """Enable the toolbar's clear-history action only if some task is completed or errored."""
    has_finished_task = any(
        self.task_completed_or_errored(task) for task in self.tasks.values())
    self.toolbar.set_clear_history_action_enabled(has_finished_task)
# Shut down before the window closes: ask the worker to stop, quit the
# transcriber thread and wait for it to finish, then save the current tasks
# to the cache and delegate to the base class closeEvent.
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
self.transcriber_worker.stop()
self.transcriber_thread.quit()
self.transcriber_thread.wait()
# Saved only after the thread has been waited on.
self.save_tasks_to_cache()
super().closeEvent(event)

View file

@ -15,6 +15,7 @@ from dataclasses import dataclass, field
from multiprocessing.connection import Connection
from threading import Thread
from typing import Any, Callable, List, Optional, Tuple, Union
import typing
import ffmpeg
import numpy as np
@ -624,8 +625,6 @@ class FileTranscriberQueueWorker(QObject):
task_updated = pyqtSignal(FileTranscriptionTask)
completed = pyqtSignal()
QUEUE_STOP_SIGNAL = None
def __init__(self, parent: Optional[QObject] = None):
super().__init__(parent)
self.queue = multiprocessing.Queue()
@ -633,8 +632,8 @@ class FileTranscriberQueueWorker(QObject):
@pyqtSlot()
def run(self):
logging.debug('Waiting for next file transcription task')
self.current_task = self.queue.get()
if self.current_task is self.QUEUE_STOP_SIGNAL:
self.current_task: Optional[FileTranscriptionTask] = self.queue.get()
if self.current_task is None:
self.completed.emit()
return
@ -697,7 +696,7 @@ class FileTranscriberQueueWorker(QObject):
self.task_updated.emit(self.current_task)
def stop(self):
self.queue.put(self.QUEUE_STOP_SIGNAL)
self.queue.put(None)
if self.current_transcriber is not None:
self.current_transcriber.stop()
if self.current_transcriber_thread is not None:

16
tests/cache_test.py Normal file
View file

@ -0,0 +1,16 @@
from buzz.cache import TasksCache
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
TranscriptionOptions)
class TestTasksCache:
    """Round-trip tests for TasksCache."""

    def test_should_save_and_load(self, tmp_path):
        """Tasks passed to save() are returned equal by load()."""
        tasks_cache = TasksCache(cache_dir=str(tmp_path))
        expected_tasks = [
            FileTranscriptionTask(
                file_path=audio_file,
                transcription_options=TranscriptionOptions(),
                file_transcription_options=FileTranscriptionOptions(file_paths=[audio_file]),
                model_path='')
            for audio_file in ('1.mp3', '2.mp3')
        ]

        tasks_cache.save(expected_tasks)

        assert tasks_cache.load() == expected_tasks

View file

@ -1,21 +1,23 @@
import os
import os.path
import pathlib
from typing import Any, Callable
from unittest.mock import Mock, patch
import pytest
import sounddevice
from PyQt6.QtCore import Qt, QCoreApplication, QSize
from PyQt6.QtGui import (QValidator)
from PyQt6.QtWidgets import (QPushButton)
from PyQt6.QtCore import QSize, Qt
from PyQt6.QtGui import QValidator
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget
from pytestqt.qtbot import QtBot
from buzz.cache import TasksCache
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
AudioDevicesComboBox, DownloadModelProgressDialog,
FileTranscriberWidget, LanguagesComboBox, MainWindow,
ModelComboBox, RecordingTranscriberWidget, TemperatureValidator,
TextDisplayBox, TranscriptionTasksTableWidget, TranscriptionViewerWidget,)
from buzz.transcriber import FileTranscriptionOptions, FileTranscriptionTask, Segment, Task, TranscriptionOptions, Model
ModelComboBox, RecordingTranscriberWidget,
TemperatureValidator, TextDisplayBox,
TranscriptionTasksTableWidget, TranscriptionViewerWidget)
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
Model, Segment, TranscriptionOptions)
class TestApplication:
@ -146,23 +148,48 @@ class TestDownloadModelProgressDialog:
assert dialog.windowModality() == Qt.WindowModality.ApplicationModal
# Fixture: yields a TasksCache stored under pytest's tmp_path and clears it
# once the test using it has finished.
@pytest.fixture
def tasks_cache(tmp_path):
cache = TasksCache(cache_dir=str(tmp_path))
yield cache
cache.clear()
def get_test_asset(filename: str):
    """Return the path to ``filename`` under ``../testdata/`` relative to this file's directory."""
    this_dir = os.path.dirname(__file__)
    return os.path.join(this_dir, '../testdata/', filename)
class TestMainWindow:
window = MainWindow()
def test_should_set_window_title_and_icon(self, qtbot: QtBot):
qtbot.add_widget(self.window)
assert self.window.windowTitle() == 'Buzz'
assert self.window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
def test_should_set_window_title_and_icon(self, qtbot):
window = MainWindow()
qtbot.add_widget(window)
assert window.windowTitle() == 'Buzz'
assert window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
window.close()
def test_should_run_transcription_task(self, qtbot: QtBot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
qtbot.add_widget(window)
def wait_until(callback: Callable[[], Any], timeout=0):
while True:
try:
QCoreApplication.processEvents()
callback()
return
except AssertionError:
pass
toolbar: QToolBar = window.findChild(QToolBar)
new_transcription_action = [action for action in toolbar.actions() if action.text() == 'New Transcription'][0]
with patch('PyQt6.QtWidgets.QFileDialog.getOpenFileNames') as open_file_names_mock:
open_file_names_mock.return_value = ([get_test_asset('whisper-french.mp3')], '')
new_transcription_action.trigger()
file_transcriber_widget: FileTranscriberWidget = window.findChild(FileTranscriberWidget)
run_button: QPushButton = file_transcriber_widget.findChild(QPushButton)
run_button.click()
def check_task_completed():
table_widget: QTableWidget = window.findChild(QTableWidget)
assert table_widget.rowCount() == 1
assert table_widget.item(0, 1).text() == 'whisper-french.mp3'
assert table_widget.item(0, 2).text() == 'Completed'
qtbot.wait_until(check_task_completed, timeout=60 * 1000)
class TestFileTranscriberWidget:
@ -280,8 +307,11 @@ class TestTranscriptionTasksTableWidget:
def test_upsert_task(self, qtbot: QtBot):
qtbot.add_widget(self.widget)
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3', transcription_options=TranscriptionOptions(
), file_transcription_options=FileTranscriptionOptions(file_paths=['testdata/whisper-french.mp3']), model_path='', status=FileTranscriptionTask.Status.QUEUED)
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']), model_path='',
status=FileTranscriptionTask.Status.QUEUED)
self.widget.upsert_task(task)