From fd49e21b0de189be7de43fefb4b185c80c8d22cb Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Sat, 22 Oct 2022 13:26:33 +0100 Subject: [PATCH] Add output format input box (#109) --- .github/workflows/ci.yml | 9 +++++---- Makefile | 7 +++---- gui.py | 39 +++++++++++++++++++++++++++++++++------ gui_test.py | 11 +++++++++-- transcriber.py | 11 ++++++----- transcriber_test.py | 17 ++++++++++++----- 6 files changed, 68 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0230088..aaccaa6c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,7 +41,7 @@ jobs: - run: poetry install - name: Test - run: poetry run make test + run: poetry run make libwhisper.so test build: runs-on: ${{ matrix.os }} strategy: @@ -51,15 +51,15 @@ jobs: CMD_BUILD: | brew install create-dmg brew install ffmpeg - poetry run make bundle_mac + poetry run make libwhisper.so bundle_mac - os: ubuntu-latest CMD_BUILD: | sudo apt update && sudo apt install ffmpeg - poetry run make bundle_linux + poetry run make libwhisper.so bundle_linux - os: windows-latest CMD_BUILD: | choco install ffmpeg - poetry run make bundle_windows + poetry run make libwhisper.so bundle_windows steps: - uses: actions/checkout@v3 with: @@ -90,6 +90,7 @@ jobs: name: Buzz-${{ runner.os }} path: | dist/Buzz*.tar.gz + dist/Buzz*.zip release: runs-on: ubuntu-latest needs: [build, test] diff --git a/Makefile b/Makefile index f4d9bee5..bbf76728 100644 --- a/Makefile +++ b/Makefile @@ -10,17 +10,16 @@ windows_zip_path := Buzz-${version}-windows.tar.gz buzz: make clean - make whisper_cpp + make libwhisper.so pyinstaller --noconfirm Buzz.spec clean: rm -rf dist/* || true test: - make whisper_cpp pytest --cov --cov-fail-under=54 --cov-report html -whisper_cpp: +libwhisper.so: gcc -O3 -std=c11 -pthread -mavx -mavx2 -mfma -mf16c -fPIC -c whisper.cpp/ggml.c -o whisper.cpp/ggml.o g++ -O3 -std=c++11 -pthread --shared -fPIC -static-libstdc++ whisper.cpp/whisper.cpp whisper.cpp/ggml.o -o libwhisper.so @@ -37,9 +36,9 @@ bundle_windows: # MAC - bundle_mac: make buzz + make zip_mac bundle_mac_local: make buzz diff --git a/gui.py b/gui.py index 9cb68e81..4a283ff6 100644 --- a/gui.py +++ b/gui.py @@ -19,7 +19,8 @@ from whisper import tokenizer import _whisper from _whisper import Task, WhisperCpp -from transcriber import FileTranscriber, RecordingTranscriber, State, Status +from transcriber import (FileTranscriber, OutputFormat, RecordingTranscriber, + State, Status) def get_platform_styles(all_platform_styles: Dict[str, str]): @@ -110,6 +111,21 @@ class TasksComboBox(QComboBox): self.taskChanged.emit(self.tasks[index]) +class OutputFormatsComboBox(QComboBox): + output_format_changed = pyqtSignal(OutputFormat) + formats: List[OutputFormat] + + def __init__(self, default_format: OutputFormat, parent: Optional[QWidget], *args) -> None: + super().__init__(parent, *args) + self.formats = [i for i in OutputFormat] + self.addItems(map(lambda format: format.value.upper(), self.formats)) + self.currentIndexChanged.connect(self.on_index_changed) + self.setCurrentText(default_format.value.title()) + + def on_index_changed(self, index: int): + self.output_format_changed.emit(self.formats[index]) + + class Quality(enum.Enum): LOW = 'low' MEDIUM = 'medium' @@ -305,6 +321,7 @@ class FileTranscriberWidget(QWidget): selected_quality = Quality.LOW selected_language: Optional[str] = None selected_task = Task.TRANSCRIBE + selected_output_format = OutputFormat.TXT model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None transcribe_progress = pyqtSignal(tuple) @@ -330,7 +347,7 @@ class FileTranscriberWidget(QWidget): self.on_language_changed) self.tasks_combo_box = TasksComboBox( - default_task=Task.TRANSCRIBE, + default_task=self.selected_task, parent=self) self.tasks_combo_box.taskChanged.connect(self.on_task_changed) @@ -338,12 +355,18 @@ class FileTranscriberWidget(QWidget): self.run_button.clicked.connect(self.on_click_run) self.run_button.setDefault(True) + output_formats_combo_box = OutputFormatsComboBox( + default_format=self.selected_output_format, parent=self) + output_formats_combo_box.output_format_changed.connect( + self.on_output_format_changed) + grid = ( ((0, 5, FormLabel('Task:', parent=self)), (5, 7, self.tasks_combo_box)), ((0, 5, FormLabel('Language:', parent=self)), (5, 7, self.languages_combo_box)), ((0, 5, FormLabel('Quality:', parent=self)), (5, 7, self.quality_combo_box)), + ((0, 5, FormLabel('Export As:', self)), (5, 7, output_formats_combo_box)), ((9, 3, self.run_button),) ) @@ -365,11 +388,15 @@ class FileTranscriberWidget(QWidget): def on_task_changed(self, task: Task): self.selected_task = task + def on_output_format_changed(self, format: OutputFormat): + self.selected_output_format = format + def on_click_run(self): default_path = FileTranscriber.get_default_output_file_path( - task=self.selected_task, input_file_path=self.file_path) + task=self.selected_task, input_file_path=self.file_path, + output_format=self.selected_output_format) (output_file, _) = QFileDialog.getSaveFileName( - self, 'Save File', default_path, 'Text files (*.txt *.srt *.vtt)') + self, 'Save File', default_path, f'Text files (*.{self.selected_output_format})') if output_file == '': return @@ -395,7 +422,7 @@ class FileTranscriberWidget(QWidget): self.file_transcriber = FileTranscriber( model=model, file_path=self.file_path, language=self.selected_language, task=self.selected_task, - output_file_path=output_file, progress_callback=self.on_transcribe_model_progress) + output_file_path=output_file, output_format=self.selected_output_format, progress_callback=self.on_transcribe_model_progress) self.file_transcriber.start() def on_download_model_progress(self, current_size: int, total_size: int): @@ -676,7 +703,7 @@ class RecordingTranscriberMainWindow(MainWindow): class FileTranscriberMainWindow(MainWindow): def __init__(self, file_path: str, parent: Optional[QWidget], *args) -> None: super().__init__(title=get_short_file_path( - file_path), w=400, h=180, parent=parent, *args) + file_path), w=400, h=210, parent=parent, *args) central_widget = FileTranscriberWidget(file_path, self) central_widget.setContentsMargins(10, 10, 10, 10) diff --git a/gui_test.py b/gui_test.py index f9997ddf..9e041f2a 100644 --- a/gui_test.py +++ b/gui_test.py @@ -3,8 +3,9 @@ from unittest.mock import patch import sounddevice from gui import (Application, AudioDevicesComboBox, - DownloadModelProgressDialog, LanguagesComboBox, MainWindow, - TranscriberProgressDialog) + DownloadModelProgressDialog, OutputFormatsComboBox, + LanguagesComboBox, MainWindow, TranscriberProgressDialog) +from transcriber import OutputFormat class TestApplication: @@ -125,6 +126,12 @@ class TestDownloadModelProgressDialog: assert self.dialog.labelText().startswith( 'Downloading resources (10.00%') +class TestFormatsComboBox: + def test_should_have_items(self): + formats_combo_box = OutputFormatsComboBox(OutputFormat.TXT, None) + assert formats_combo_box.itemText(0) == 'TXT' + assert formats_combo_box.itemText(1) == 'SRT' + assert formats_combo_box.itemText(2) == 'VTT' class TestMainWindow: def test_should_init(self): diff --git a/transcriber.py b/transcriber.py index 3c7dfdb3..bd15321a 100644 --- a/transcriber.py +++ b/transcriber.py @@ -179,8 +179,8 @@ def capture_fd(fd: int): class OutputFormat(enum.Enum): TXT = 'txt' - VTT = 'vtt' SRT = 'srt' + VTT = 'vtt' def to_timestamp(ms: float) -> str: @@ -239,7 +239,8 @@ class FileTranscriber: def __init__( self, model: Union[whisper.Whisper, _whisper.WhisperCpp], language: Optional[str], - task: _whisper.Task, file_path: str, output_file_path: str, + task: _whisper.Task, file_path: str, + output_file_path: str, output_format: OutputFormat, progress_callback: Callable[[int, int], None] = lambda *_: None, open_file_on_complete=True) -> None: self.model = model @@ -251,7 +252,7 @@ class FileTranscriber: self.open_file_on_complete = open_file_on_complete _, extension = os.path.splitext(self.output_file_path) - self.output_format = OutputFormat(extension[1:]) + self.output_format = output_format def start(self): self.current_thread = Thread(target=self.transcribe) @@ -316,5 +317,5 @@ class FileTranscriber: return self.stopped @classmethod - def get_default_output_file_path(cls, task: _whisper.Task, input_file_path: str): - return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).txt' + def get_default_output_file_path(cls, task: _whisper.Task, input_file_path: str, output_format: OutputFormat): + return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).{output_format.value}' diff --git a/transcriber_test.py b/transcriber_test.py index a9d68283..7d6a08be 100644 --- a/transcriber_test.py +++ b/transcriber_test.py @@ -5,8 +5,8 @@ import tempfile import pytest from _whisper import Task, WhisperCpp -from transcriber import (FileTranscriber, RecordingTranscriber, Status, - to_timestamp) +from transcriber import (OutputFormat, FileTranscriber, RecordingTranscriber, + Status, to_timestamp) class TestRecordingTranscriber: @@ -21,8 +21,15 @@ class TestRecordingTranscriber: class TestFileTranscriber: def test_default_output_file(self): - assert FileTranscriber.get_default_output_file_path( - Task.TRANSLATE, '/a/b/c.txt').startswith('/a/b/c (Translated on ') + srt = FileTranscriber.get_default_output_file_path( + Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.TXT) + assert srt.startswith('/a/b/c (Translated on ') + assert srt.endswith('.txt') + + srt = FileTranscriber.get_default_output_file_path( + Task.TRANSLATE, '/a/b/c.mp4', OutputFormat.SRT) + assert srt.startswith('/a/b/c (Translated on ') + assert srt.endswith('.srt') @pytest.mark.skip(reason='test ggml model not working for') def test_transcribe_whisper_cpp(self): @@ -30,7 +37,7 @@ class TestFileTranscriber: transcriber = FileTranscriber( model=WhisperCpp('testdata/ggml-tiny.bin'), language='en', task=Task.TRANSCRIBE, file_path='testdata/whisper.m4a', - output_file_path=output_file_path, + output_file_path=output_file_path, output_format=OutputFormat.TXT, open_file_on_complete=False) transcriber.start() transcriber.join()