fix: cli producing blank filenames (#681)

This commit is contained in:
Chidi Williams 2024-02-24 12:10:45 +00:00 committed by GitHub
parent 397dadd7a2
commit dfac983f13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 51 additions and 77 deletions

View file

@ -106,7 +106,6 @@ zip_mac:
ditto -c -k --keepParent "${mac_app_path}" "${mac_zip_path}"
codesign_all_mac: dist/Buzz.app
codesign --force --options=runtime --sign "$$BUZZ_CODESIGN_IDENTITY" --timestamp dist/Buzz.app/Contents/Resources/ffmpeg
for i in $$(find dist/Buzz.app/Contents/Resources/torch/bin -name "*" -type f); \
do \
codesign --force --options=runtime --sign "$$BUZZ_CODESIGN_IDENTITY" --timestamp "$$i"; \

View file

@ -4,7 +4,6 @@ import typing
from PyQt6.QtCore import QCommandLineParser, QCommandLineOption
from buzz.widgets.application import Application
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel
from buzz.store.keyring_store import KeyringStore
from buzz.transcriber.transcriber import (
@ -15,6 +14,7 @@ from buzz.transcriber.transcriber import (
LANGUAGES,
OutputFormat,
)
from buzz.widgets.application import Application
class CommandLineError(Exception):
@ -185,7 +185,8 @@ def parse(app: Application, parser: QCommandLineParser):
openai_access_token=openai_access_token,
)
file_transcription_options = FileTranscriptionOptions(
file_paths=file_paths, output_formats=output_formats
file_paths=file_paths,
output_formats=output_formats,
)
for file_path in file_paths:

View file

@ -62,3 +62,9 @@ class Settings:
def sync(self):
self.settings.sync()
def get_default_export_file_template(self) -> str:
return self.value(
Settings.Key.DEFAULT_EXPORT_FILE_NAME,
"{{ input_file_name }} ({{ task }}d on {{ date_time }})",
)

View file

@ -9,6 +9,7 @@ from dataclasses_json import dataclass_json, config, Exclude
from buzz.locale import _
from buzz.model_loader import TranscriptionModel
from buzz.settings.settings import Settings
DEFAULT_WHISPER_TEMPERATURE = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
@ -153,7 +154,6 @@ class FileTranscriptionOptions:
file_paths: Optional[List[str]] = None
url: Optional[str] = None
output_formats: Set["OutputFormat"] = field(default_factory=set)
default_output_file_name: str = ""
@dataclass_json
@ -235,13 +235,22 @@ SUPPORTED_AUDIO_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg);;\
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)"
def get_output_file_path(task: FileTranscriptionTask, output_format: OutputFormat):
def get_output_file_path(
task: FileTranscriptionTask,
output_format: OutputFormat,
export_file_name_template: str | None = None,
):
input_file_name = os.path.splitext(os.path.basename(task.file_path))[0]
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
export_file_name_template = (
export_file_name_template
if export_file_name_template is not None
else Settings().get_default_export_file_template()
)
output_file_name = (
task.file_transcription_options.default_output_file_name.replace(
"{{ input_file_name }}", input_file_name
)
export_file_name_template.replace("{{ input_file_name }}", input_file_name)
.replace("{{ task }}", task.transcription_options.task.value)
.replace("{{ language }}", task.transcription_options.language or "")
.replace("{{ model_type }}", task.transcription_options.model.model_type.value)

View file

@ -59,10 +59,6 @@ class MainWindow(QMainWindow):
self.shortcut_settings = ShortcutSettings(settings=self.settings)
self.shortcuts = self.shortcut_settings.load()
self.default_export_file_name = self.settings.value(
Settings.Key.DEFAULT_EXPORT_FILE_NAME,
"{{ input_file_name }} ({{ task }}d on {{ date_time }})",
)
self.tasks = {}
@ -85,7 +81,6 @@ class MainWindow(QMainWindow):
self.preferences = self.load_preferences(settings=self.settings)
self.menu_bar = MenuBar(
shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
preferences=self.preferences,
parent=self,
)
@ -99,9 +94,6 @@ class MainWindow(QMainWindow):
self.menu_bar.openai_api_key_changed.connect(
self.on_openai_access_token_changed
)
self.menu_bar.default_export_file_name_changed.connect(
self.default_export_file_name_changed
)
self.menu_bar.preferences_changed.connect(self.on_preferences_changed)
self.setMenuBar(self.menu_bar)
@ -132,7 +124,6 @@ class MainWindow(QMainWindow):
self.folder_watcher = TranscriptionTaskFolderWatcher(
tasks=self.tasks,
preferences=self.preferences.folder_watch,
default_export_file_name=self.default_export_file_name,
)
self.folder_watcher.task_found.connect(self.add_task)
self.folder_watcher.find_tasks()
@ -259,7 +250,6 @@ class MainWindow(QMainWindow):
file_transcriber_window = FileTranscriberWidget(
file_paths=file_paths,
url=url,
default_output_file_name=self.default_export_file_name,
parent=self,
flags=Qt.WindowType.Window,
)
@ -273,13 +263,6 @@ class MainWindow(QMainWindow):
def on_openai_access_token_changed(access_token: str):
KeyringStore().set_password(KeyringStore.Key.OPENAI_API_KEY, access_token)
def default_export_file_name_changed(self, default_export_file_name: str):
self.default_export_file_name = default_export_file_name
self.settings.set_value(
Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name
)
self.folder_watcher.default_export_file_name = default_export_file_name
def open_transcript_viewer(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
for selected_row in selected_rows:

View file

@ -20,21 +20,18 @@ class MenuBar(QMenuBar):
import_url_action_triggered = pyqtSignal()
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
preferences_changed = pyqtSignal(Preferences)
preferences_dialog: Optional[PreferencesDialog] = None
def __init__(
self,
shortcuts: Dict[str, str],
default_export_file_name: str,
preferences: Preferences,
parent: Optional[QWidget] = None,
):
super().__init__(parent)
self.shortcuts = shortcuts
self.default_export_file_name = default_export_file_name
self.preferences = preferences
self.import_action = QAction(_("Import File..."), self)
@ -70,15 +67,11 @@ class MenuBar(QMenuBar):
def on_preferences_action_triggered(self):
preferences_dialog = PreferencesDialog(
shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
preferences=self.preferences,
parent=self,
)
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
preferences_dialog.default_export_file_name_changed.connect(
self.default_export_file_name_changed
)
preferences_dialog.finished.connect(self.on_preferences_dialog_finished)
preferences_dialog.open()

View file

@ -4,6 +4,7 @@ from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox
from openai import AuthenticationError, OpenAI
from buzz.settings.settings import Settings
from buzz.store.keyring_store import KeyringStore
from buzz.widgets.line_edit import LineEdit
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
@ -11,11 +12,9 @@ from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
class GeneralPreferencesWidget(QWidget):
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(
self,
default_export_file_name: str,
keyring_store=KeyringStore(),
parent: Optional[QWidget] = None,
):
@ -41,15 +40,22 @@ class GeneralPreferencesWidget(QWidget):
layout.addRow("OpenAI API key", self.openai_api_key_line_edit)
layout.addRow("", self.test_openai_api_key_button)
self.settings = Settings()
default_export_file_name = self.settings.get_default_export_file_template()
default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
default_export_file_name_line_edit.textChanged.connect(
self.default_export_file_name_changed
self.on_default_export_file_name_changed
)
default_export_file_name_line_edit.setMinimumWidth(200)
layout.addRow("Default export file name", default_export_file_name_line_edit)
self.setLayout(layout)
def on_default_export_file_name_changed(self, text: str):
self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, text)
def update_test_openai_api_key_button(self):
self.test_openai_api_key_button.setEnabled(len(self.openai_api_key) > 0)

View file

@ -81,7 +81,6 @@ class FileTranscriptionPreferences:
openai_access_token: Optional[str],
file_paths: Optional[List[str]] = None,
url: Optional[str] = None,
default_output_file_name: str = "",
) -> Tuple[TranscriptionOptions, FileTranscriptionOptions]:
return (
TranscriptionOptions(
@ -97,6 +96,5 @@ class FileTranscriptionPreferences:
output_formats=self.output_formats,
file_paths=file_paths,
url=url,
default_output_file_name=default_output_file_name,
),
)

View file

@ -27,14 +27,12 @@ class PreferencesDialog(QDialog):
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)
folder_watch_config_changed = pyqtSignal(FolderWatchPreferences)
default_export_file_name_changed = pyqtSignal(str)
preferences_changed = pyqtSignal(Preferences)
def __init__(
self,
# TODO: move shortcuts and default export file name into preferences
shortcuts: Dict[str, str],
default_export_file_name: str,
preferences: Preferences,
parent: Optional[QWidget] = None,
) -> None:
@ -47,13 +45,8 @@ class PreferencesDialog(QDialog):
layout = QVBoxLayout(self)
tab_widget = QTabWidget(self)
general_tab_widget = GeneralPreferencesWidget(
default_export_file_name=default_export_file_name, parent=self
)
general_tab_widget = GeneralPreferencesWidget(parent=self)
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
general_tab_widget.default_export_file_name_changed.connect(
self.default_export_file_name_changed
)
tab_widget.addTab(general_tab_widget, _("General"))
models_tab_widget = ModelsPreferencesWidget(parent=self)

View file

@ -40,7 +40,6 @@ class FileTranscriberWidget(QWidget):
def __init__(
self,
default_output_file_name: str,
file_paths: Optional[List[str]] = None,
url: Optional[str] = None,
parent: Optional[QWidget] = None,
@ -66,7 +65,6 @@ class FileTranscriberWidget(QWidget):
openai_access_token=openai_access_token,
file_paths=self.file_paths,
url=url,
default_output_file_name=default_output_file_name,
)
layout = QVBoxLayout(self)

View file

@ -19,12 +19,10 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
self,
tasks: Dict[int, FileTranscriptionTask],
preferences: FolderWatchPreferences,
default_export_file_name: str,
parent: QObject = None,
):
super().__init__(parent)
self.tasks = tasks
self.default_export_file_name = default_export_file_name
self.set_preferences(preferences)
self.directoryChanged.connect(self.find_tasks)
@ -58,7 +56,6 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
file_transcription_options,
) = self.preferences.file_transcription_options.to_transcription_options(
openai_access_token=openai_access_token,
default_output_file_name=self.default_export_file_name,
file_paths=[file_path],
)
model_path = transcription_options.model.get_local_model_path()

5
tests/audio.py Normal file
View file

@ -0,0 +1,5 @@
import os.path
test_audio_path = os.path.join(
os.path.dirname(__file__), "../testdata/whisper-french.mp3"
)

View file

@ -1,3 +1,4 @@
import glob
import logging
import os
import platform
@ -22,6 +23,7 @@ from buzz.transcriber.transcriber import (
Segment,
)
from buzz.transcriber.whisper_file_transcriber import WhisperFileTranscriber
from tests.audio import test_audio_path
from tests.model_loader import get_model_path
UNSUPPORTED_ON_LINUX_REASON = "Whisper not supported on Linux"
@ -29,20 +31,18 @@ UNSUPPORTED_ON_LINUX_REASON = "Whisper not supported on Linux"
class TestWhisperFileTranscriber:
@pytest.mark.parametrize(
"file_path,output_format,expected_file_path,default_output_file_name",
"file_path,output_format,expected_file_path",
[
pytest.param(
"/a/b/c.mp4",
OutputFormat.SRT,
"/a/b/c-translate--Whisper-tiny.srt",
"{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
marks=pytest.mark.skipif(platform.system() == "Windows", reason=""),
),
pytest.param(
"C:\\a\\b\\c.mp4",
OutputFormat.SRT,
"C:\\a\\b\\c-translate--Whisper-tiny.srt",
"{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
marks=pytest.mark.skipif(platform.system() != "Windows", reason=""),
),
],
@ -52,18 +52,16 @@ class TestWhisperFileTranscriber:
file_path: str,
output_format: OutputFormat,
expected_file_path: str,
default_output_file_name: str,
):
file_path = get_output_file_path(
task=FileTranscriptionTask(
file_path=file_path,
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(
file_paths=[], default_output_file_name=default_output_file_name
),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
),
output_format=output_format,
export_file_name_template="{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
)
assert file_path == expected_file_path
@ -85,17 +83,20 @@ class TestWhisperFileTranscriber:
def test_default_output_file_with_date(
self, file_path: str, expected_starts_with: str
):
export_file_name_template = (
"{{ input_file_name }} (Translated on {{ date_time }})"
)
srt = get_output_file_path(
task=FileTranscriptionTask(
file_path=file_path,
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(
file_paths=[],
default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})",
),
model_path="",
),
output_format=OutputFormat.TXT,
export_file_name_template=export_file_name_template,
)
assert srt.startswith(expected_starts_with)
@ -107,11 +108,11 @@ class TestWhisperFileTranscriber:
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(
file_paths=[],
default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})",
),
model_path="",
),
output_format=OutputFormat.SRT,
export_file_name_template=export_file_name_template,
)
assert srt.startswith(expected_starts_with)
assert srt.endswith(".srt")
@ -270,12 +271,11 @@ class TestWhisperFileTranscriber:
)
def test_transcribe_from_folder_watch_source(self, qtbot):
file_path = tempfile.mktemp(suffix=".mp3")
shutil.copy("testdata/whisper-french.mp3", file_path)
shutil.copy(test_audio_path, file_path)
file_transcription_options = FileTranscriptionOptions(
file_paths=[file_path],
output_formats={OutputFormat.TXT},
default_output_file_name="{{ input_file_name }}",
)
transcription_options = TranscriptionOptions()
model_path = get_model_path(transcription_options.model)
@ -298,12 +298,7 @@ class TestWhisperFileTranscriber:
assert os.path.isfile(
os.path.join(output_directory, os.path.basename(file_path))
)
assert os.path.isfile(
os.path.join(
output_directory,
os.path.splitext(os.path.basename(file_path))[0] + ".txt",
)
)
assert len(glob.glob("*.txt", root_dir=output_directory)) > 0
@pytest.mark.skip()
def test_transcribe_stop(self):

View file

@ -10,7 +10,6 @@ class TestFileTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=["testdata/whisper-french.mp3"],
default_output_file_name="",
)
qtbot.add_widget(widget)
assert widget.windowTitle() == "whisper-french.mp3"
@ -18,7 +17,6 @@ class TestFileTranscriberWidget:
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=["testdata/whisper-french.mp3"],
default_output_file_name="",
)
qtbot.add_widget(widget)

View file

@ -11,7 +11,6 @@ class TestMenuBar:
def test_open_preferences_dialog(self, qtbot):
menu_bar = MenuBar(
shortcuts=ShortcutSettings(Settings()).load(),
default_export_file_name="",
preferences=Preferences.load(QSettings()),
)
qtbot.add_widget(menu_bar)

View file

@ -10,9 +10,7 @@ from buzz.widgets.preferences_dialog.general_preferences_widget import (
class TestGeneralPreferencesWidget:
def test_should_disable_test_button_if_no_api_key(self, qtbot):
widget = GeneralPreferencesWidget(
keyring_store=self.get_keyring_store(""), default_export_file_name=""
)
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(""))
qtbot.add_widget(widget)
test_button = widget.findChild(QPushButton)
@ -30,7 +28,6 @@ class TestGeneralPreferencesWidget:
def test_should_test_openai_api_key(self, qtbot):
widget = GeneralPreferencesWidget(
keyring_store=self.get_keyring_store("wrong-api-key"),
default_export_file_name="",
)
qtbot.add_widget(widget)

View file

@ -10,7 +10,6 @@ class TestPreferencesDialog:
def test_create(self, qtbot: QtBot):
dialog = PreferencesDialog(
shortcuts={},
default_export_file_name="",
preferences=Preferences.load(QSettings()),
)
qtbot.add_widget(dialog)

View file

@ -48,7 +48,6 @@ class TestTranscriptionTaskFolderWatcher:
output_formats=set(),
),
),
default_export_file_name="",
)
shutil.copy("testdata/whisper-french.mp3", input_directory)
@ -89,7 +88,6 @@ class TestTranscriptionTaskFolderWatcher:
output_formats=set(),
),
),
default_export_file_name="",
)
# Ignored because already in tasks