diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cf4c7041..9717b662 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,6 +20,8 @@ jobs: - os: windows-latest steps: - uses: actions/checkout@v3 + with: + submodules: recursive - uses: actions/setup-python@v4 with: python-version: '3.9.13' @@ -60,6 +62,8 @@ jobs: poetry run make bundle_windows steps: - uses: actions/checkout@v3 + with: + submodules: recursive - uses: actions/setup-python@v4 with: python-version: '3.9.13' diff --git a/.gitignore b/.gitignore index adaae384..ff18340b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ build/ .coverage .env htmlcov/ +*.so diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..fa83e220 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "whisper.cpp"] + path = whisper.cpp + url = https://github.com/ggerganov/whisper.cpp diff --git a/Buzz.spec b/Buzz.spec index 831a45c1..b5d062c0 100644 --- a/Buzz.spec +++ b/Buzz.spec @@ -17,6 +17,7 @@ datas += copy_metadata('filelock') datas += copy_metadata('numpy') datas += copy_metadata('tokenizers') datas += collect_data_files('whisper') +datas += [('libwhisper.so', '.')] def get_ffmpeg(): diff --git a/Makefile b/Makefile index 8c6c00af..f4d9bee5 100644 --- a/Makefile +++ b/Makefile @@ -10,13 +10,19 @@ windows_zip_path := Buzz-${version}-windows.tar.gz buzz: make clean + make whisper_cpp pyinstaller --noconfirm Buzz.spec clean: rm -rf dist/* || true test: - pytest --cov --cov-fail-under=57 --cov-report html + make whisper_cpp + pytest --cov --cov-fail-under=54 --cov-report html + +whisper_cpp: + gcc -O3 -std=c11 -pthread -mavx -mavx2 -mfma -mf16c -fPIC -c whisper.cpp/ggml.c -o whisper.cpp/ggml.o + g++ -O3 -std=c++11 -pthread --shared -fPIC -static-libstdc++ whisper.cpp/whisper.cpp whisper.cpp/ggml.o -o libwhisper.so version: poetry version ${version} diff --git a/_whisper.py b/_whisper.py index ba69d825..70c87469 100644 --- a/_whisper.py +++ b/_whisper.py @@ -1,13 +1,17 @@ - +import ctypes +import enum import hashlib +import logging import os +import pathlib import warnings -from typing import Callable, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import requests import torch import whisper +from appdirs import user_cache_dir from whisper import Whisper from whisper.audio import * from whisper.decoding import * @@ -19,18 +23,166 @@ class Stopped(Exception): pass +@dataclass +class Segment: + start: float + end: float + text: str + + +class Task(enum.Enum): + TRANSLATE = "translate" + TRANSCRIBE = "transcribe" + + +class WhisperFullParams(ctypes.Structure): + _fields_ = [ + ("strategy", ctypes.c_int), + ("n_threads", ctypes.c_int), + ("offset_ms", ctypes.c_int), + ("translate", ctypes.c_bool), + ("no_context", ctypes.c_bool), + ("print_special_tokens", ctypes.c_bool), + ("print_progress", ctypes.c_bool), + ("print_realtime", ctypes.c_bool), + ("print_timestamps", ctypes.c_bool), + ("language", ctypes.c_char_p), + ("greedy", ctypes.c_int * 1), + ] + + +whisper_cpp = ctypes.CDLL(str(pathlib.Path().absolute() / "libwhisper.so"), winmode=1) + +whisper_cpp.whisper_init.restype = ctypes.c_void_p +whisper_cpp.whisper_full_default_params.restype = WhisperFullParams +whisper_cpp.whisper_full_get_segment_text.restype = ctypes.c_char_p + + +def whisper_cpp_progress(lines: str) -> Optional[int]: + """Extracts the progress of a whisper.cpp transcription. + + The log lines have the following format: + whisper_full: progress = 20%\n + """ + + # Example log line: "whisper_full: progress = 20%" + progress_lines = list(filter(lambda line: line.startswith( + 'whisper_full: progress'), lines.split('\n'))) + if len(progress_lines) == 0: + return None + last_word = progress_lines[-1].split(' ')[-1] + return min(int(last_word[:-1]), 100) + + +def whisper_cpp_params(language: str, task: Task, print_realtime=False, print_progress=False): + params = whisper_cpp.whisper_full_default_params(0) + params.print_realtime = print_realtime + params.print_progress = print_progress + params.language = language.encode('utf-8') + params.translate = task == Task.TRANSLATE + return params + + +class WhisperCpp: + def __init__(self, model: str) -> None: + self.ctx = whisper_cpp.whisper_init(model.encode('utf-8')) + + def transcribe(self, audio: Union[np.ndarray, str], params: Any): + if isinstance(audio, str): + audio = whisper.audio.load_audio(audio) + + result = whisper_cpp.whisper_full(ctypes.c_void_p( + self.ctx), params, audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), len(audio)) + if result != 0: + raise Exception(f'Error from whisper.cpp: {result}') + + segments: List[Segment] = [] + + n_segments = whisper_cpp.whisper_full_n_segments( + ctypes.c_void_p(self.ctx)) + for i in range(n_segments): + txt = whisper_cpp.whisper_full_get_segment_text( + ctypes.c_void_p(self.ctx), i) + t0 = whisper_cpp.whisper_full_get_segment_t0( + ctypes.c_void_p(self.ctx), i) + t1 = whisper_cpp.whisper_full_get_segment_t1( + ctypes.c_void_p(self.ctx), i) + + segments.append( + Segment(start=t0*10, # centisecond to ms + end=t1*10, # centisecond to ms + text=txt.decode('utf-8'))) + + return { + 'segments': segments, + 'text': ''.join([segment.text for segment in segments])} + + def __del__(self): + whisper_cpp.whisper_free(ctypes.c_void_p(self.ctx)) + + +WHISPER_CPP_SHA256 = { + 'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21', + 'small': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe', + 'base': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b' +} + + class ModelLoader: stopped = False - def __init__(self, name: str, + def __init__(self, name: str, use_whisper_cpp=False, on_download_model_chunk: Callable[[int, int], None] = lambda *_: None) -> None: self.name = name self.on_download_model_chunk = on_download_model_chunk + self.use_whisper_cpp = use_whisper_cpp - def load(self): - return load_model( - name=self.name, is_stopped=self.is_stopped, - on_download_model_chunk=self.on_download_model_chunk) + def load(self) -> Union[Whisper, WhisperCpp]: + if self.use_whisper_cpp: + base_dir = user_cache_dir('Buzz') + model_path = os.path.join( + base_dir, f'ggml-model-whisper-{self.name}.bin') + + if os.path.exists(model_path) and not os.path.isfile(model_path): + raise RuntimeError( + f"{model_path} exists and is not a regular file") + + expected_sha256 = WHISPER_CPP_SHA256[self.name] + + if os.path.isfile(model_path): + model_bytes = open(model_path, "rb").read() + if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: + return WhisperCpp(model_path) + + logging.debug( + f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file") + + url = f'https://ggml.buzz.chidiwilliams.com/ggml-model-whisper-{self.name}.bin' + + with requests.get(url, stream=True) as source, open(model_path, 'wb') as output: + source.raise_for_status() + + current_size = 0 + total_size = int(source.headers.get('Content-Length', 0)) + for chunk in source.iter_content(chunk_size=DONWLOAD_CHUNK_SIZE): + if self.is_stopped(): + os.unlink(model_path) + raise Stopped + + output.write(chunk) + current_size += len(chunk) + self.on_download_model_chunk(current_size, total_size) + + model_bytes = open(model_path, "rb").read() + if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.") + + return WhisperCpp(model_path) + else: + return load_model( + name=self.name, is_stopped=self.is_stopped, + on_download_model_chunk=self.on_download_model_chunk) def stop(self): self.stopped = True diff --git a/gui.py b/gui.py index 7c797ccf..6cd7b52a 100644 --- a/gui.py +++ b/gui.py @@ -4,12 +4,13 @@ import os import platform import sys from datetime import datetime -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import humanize import sounddevice import whisper -from PyQt6.QtCore import QDateTime, QObject, QRect, Qt, QTimer, pyqtSignal +from PyQt6.QtCore import (QDateTime, QObject, QRect, QSettings, Qt, QTimer, + pyqtSignal) from PyQt6.QtGui import QAction, QKeySequence, QTextCursor from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QGridLayout, QLabel, QMainWindow, QPlainTextEdit, @@ -17,8 +18,8 @@ from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QGridLayout, from whisper import tokenizer import _whisper -from transcriber import (FileTranscriber, RecordingTranscriber, State, Status, - Task) +from _whisper import Task, WhisperCpp +from transcriber import FileTranscriber, RecordingTranscriber, State, Status def get_platform_styles(all_platform_styles: Dict[str, str]): @@ -243,16 +244,13 @@ class TranscriberWithSignal(QObject): status_changed = pyqtSignal(Status) def __init__( - self, model: whisper.Whisper, language: Optional[str], - task: Task, parent: Optional[QWidget], input_device_index: Optional[int], - *args, - ) -> None: + self, model: Union[whisper.Whisper, WhisperCpp], language: Optional[str], + task: Task, input_device_index: Optional[int], parent: Optional[QWidget], *args) -> None: super().__init__(parent, *args) self.transcriber = RecordingTranscriber( model=model, language=language, status_callback=self.on_next_status, task=task, - input_device_index=input_device_index, - ) + input_device_index=input_device_index) def start_recording(self): self.transcriber.start_recording() @@ -300,7 +298,7 @@ def get_model_name(quality: Quality, language: Optional[str]) -> str: Quality.LOW: ('tiny', 'tiny.en'), Quality.MEDIUM: ('base', 'base.en'), Quality.HIGH: ('small', 'small.en'), - }[quality][1 if language == 'en' else 0] + }[quality][0] class FileTranscriberWidget(QWidget): @@ -316,6 +314,8 @@ class FileTranscriberWidget(QWidget): layout = QGridLayout(self) + self.settings = Settings(self) + self.file_path = file_path self.quality_combo_box = QualityComboBox( @@ -369,18 +369,22 @@ class FileTranscriberWidget(QWidget): default_path = FileTranscriber.get_default_output_file_path( task=self.selected_task, input_file_path=self.file_path) (output_file, _) = QFileDialog.getSaveFileName( - self, 'Save File', default_path, 'Text files (*.txt)') + self, 'Save File', default_path, 'Text files (*.txt *.srt *.vtt)') if output_file == '': return + use_whisper_cpp = self.settings.enable_ggml_inference( + ) and self.selected_language != None + self.run_button.setDisabled(True) model_name = get_model_name( self.selected_quality, self.selected_language) logging.debug(f'Loading model: {model_name}') self.model_loader = _whisper.ModelLoader( - name=model_name, on_download_model_chunk=self.on_download_model_progress) + name=model_name, use_whisper_cpp=use_whisper_cpp, + on_download_model_chunk=self.on_download_model_progress) try: model = self.model_loader.load() @@ -451,6 +455,16 @@ class FileTranscriberWidget(QWidget): self.model_download_progress_dialog = None +class Settings(QSettings): + ENABLE_GGML_INFERENCE = 'enable_ggml_inference' + + def __init__(self, parent: Optional[QWidget], *args): + super().__init__('Buzz', 'Buzz', parent, *args) + + def enable_ggml_inference(self): + return self.value(self.ENABLE_GGML_INFERENCE, False) + + class RecordingTranscriberWidget(QWidget): current_status = RecordButton.Status.STOPPED selected_quality = Quality.LOW @@ -458,12 +472,15 @@ class RecordingTranscriberWidget(QWidget): selected_device_id: Optional[int] selected_task = Task.TRANSCRIBE model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None + settings: Settings def __init__(self, parent: Optional[QWidget]) -> None: super().__init__(parent) layout = QGridLayout(self) + self.settings = Settings(self) + self.quality_combo_box = QualityComboBox( default_quality=self.selected_quality, parent=self) @@ -541,12 +558,16 @@ class RecordingTranscriberWidget(QWidget): def start_recording(self): self.record_button.setDisabled(True) + use_whisper_cpp = self.settings.enable_ggml_inference( + ) and self.selected_language != None + model_name = get_model_name( self.selected_quality, self.selected_language) logging.debug(f'Loading model: {model_name}') self.model_loader = _whisper.ModelLoader( - name=model_name, on_download_model_chunk=self.on_download_model_progress) + name=model_name, use_whisper_cpp=use_whisper_cpp, + on_download_model_chunk=self.on_download_model_progress) try: model = self.model_loader.load() @@ -618,13 +639,29 @@ class MainWindow(QMainWindow): self.file_menu = menu.addMenu("&File") self.file_menu.addAction(import_audio_file_action) + self.settings = Settings(self) + + enable_ggml_inference_action = QAction( + '&Enable GGML Inference', self) + enable_ggml_inference_action.setCheckable(True) + enable_ggml_inference_action.setChecked( + bool(self.settings.enable_ggml_inference())) + enable_ggml_inference_action.triggered.connect( + self.on_toggle_enable_ggml_inference) + + self.settings_menu = menu.addMenu('&Settings') + self.settings_menu.addAction(enable_ggml_inference_action) + def on_import_audio_file_action(self): (file_path, _) = QFileDialog.getOpenFileName( - self, 'Select audio file', '', 'Audio Files (*.mp3 *.wav *.m4a *.ogg);;Video Files (*.mp4 *.webm)') + self, 'Select audio file', '', 'Audio Files (*.mp3 *.wav *.m4a *.ogg);;Video Files (*.mp4 *.webm *.ogm)') if file_path == '': return self.new_import_window_triggered.emit((file_path, self.geometry())) + def on_toggle_enable_ggml_inference(self, state: bool): + self.settings.setValue(Settings.ENABLE_GGML_INFERENCE, state) + class RecordingTranscriberMainWindow(MainWindow): def __init__(self, parent: Optional[QWidget], *args) -> None: diff --git a/gui_test.py b/gui_test.py index 1bf61b32..f9997ddf 100644 --- a/gui_test.py +++ b/gui_test.py @@ -3,7 +3,7 @@ from unittest.mock import patch import sounddevice from gui import (Application, AudioDevicesComboBox, - DownloadModelProgressDialog, LanguagesComboBox, + DownloadModelProgressDialog, LanguagesComboBox, MainWindow, TranscriberProgressDialog) @@ -124,3 +124,9 @@ class TestDownloadModelProgressDialog: self.dialog.setValue(123456) assert self.dialog.labelText().startswith( 'Downloading resources (10.00%') + + +class TestMainWindow: + def test_should_init(self): + main_window = MainWindow(title='', w=200, h=200, parent=None) + assert main_window != None diff --git a/poetry.lock b/poetry.lock index c107a8e0..7b38f051 100644 --- a/poetry.lock +++ b/poetry.lock @@ -192,11 +192,11 @@ altgraph = ">=0.17" [[package]] name = "more-itertools" -version = "8.14.0" +version = "9.0.0" description = "More routines for operating on iterables, beyond itertools" category = "main" optional = false -python-versions = ">=3.5" +python-versions = ">=3.7" [[package]] name = "numpy" @@ -641,7 +641,7 @@ dev = ["pytest"] type = "git" url = "https://github.com/openai/whisper.git" reference = "HEAD" -resolved_reference = "d18e9ea5dd2ca57c697e8e55f9e654f06ede25d0" +resolved_reference = "7f3e408e092e73d472036ae2e3fba1e7c68ca4e6" [metadata] lock-version = "1.1" @@ -827,8 +827,8 @@ macholib = [ {file = "macholib-1.16.2.tar.gz", hash = "sha256:557bbfa1bb255c20e9abafe7ed6cd8046b48d9525db2f9b77d3122a63a2a8bf8"}, ] more-itertools = [ - {file = "more-itertools-8.14.0.tar.gz", hash = "sha256:c09443cd3d5438b8dafccd867a6bc1cb0894389e90cb53d227456b0b0bccb750"}, - {file = "more_itertools-8.14.0-py3-none-any.whl", hash = "sha256:1bc4f91ee5b1b31ac7ceacc17c09befe6a40a503907baf9c839c229b5095cfd2"}, + {file = "more-itertools-9.0.0.tar.gz", hash = "sha256:5a6257e40878ef0520b1803990e3e22303a41b5714006c32a3fd8304b26ea1ab"}, + {file = "more_itertools-9.0.0-py3-none-any.whl", hash = "sha256:250e83d7e81d0c87ca6bd942e6aeab8cc9daa6096d12c5308f3f92fa5e5c1f41"}, ] numpy = [ {file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"}, diff --git a/testdata/ggml-tiny.bin b/testdata/ggml-tiny.bin new file mode 100644 index 00000000..1351aeba Binary files /dev/null and b/testdata/ggml-tiny.bin differ diff --git a/testdata/whisper.m4a b/testdata/whisper.m4a new file mode 100644 index 00000000..46a68ad8 Binary files /dev/null and b/testdata/whisper.m4a differ diff --git a/transcriber.py b/transcriber.py index 2cde8ded..3c7dfdb3 100644 --- a/transcriber.py +++ b/transcriber.py @@ -1,17 +1,22 @@ import datetime import enum import logging +import multiprocessing import os import platform +import select import subprocess -from threading import Lock, Thread -from typing import Callable, Optional +import threading +from contextlib import contextmanager +from threading import Thread +from typing import Any, Callable, List, Optional, Union import numpy as np import sounddevice import whisper import _whisper +from _whisper import Segment class State(enum.Enum): @@ -25,20 +30,16 @@ class Status: self.text = text -class Task(enum.Enum): - TRANSLATE = "translate" - TRANSCRIBE = "transcribe" - - class RecordingTranscriber: """Transcriber records audio from a system microphone and transcribes it into text using Whisper.""" current_thread: Optional[Thread] current_stream: Optional[sounddevice.InputStream] is_running = False + MAX_QUEUE_SIZE = 10 - def __init__(self, model: whisper.Whisper, language: Optional[str], - status_callback: Callable[[Status], None], task: Task, + def __init__(self, model: Union[whisper.Whisper, _whisper.WhisperCpp], language: Optional[str], + status_callback: Callable[[Status], None], task: _whisper.Task, input_device_index: Optional[int] = None) -> None: self.model = model self.current_stream = None @@ -52,7 +53,7 @@ class RecordingTranscriber: # pause queueing if more than 3 batches behind self.max_queue_size = 3 * self.n_batch_samples self.queue = np.ndarray([], dtype=np.float32) - self.mutex = Lock() + self.mutex = threading.Lock() self.text = '' def start_recording(self): @@ -74,27 +75,35 @@ class RecordingTranscriber: while self.is_running: self.mutex.acquire() if self.queue.size >= self.n_batch_samples: - batch = self.queue[:self.n_batch_samples] + samples = self.queue[:self.n_batch_samples] self.queue = self.queue[self.n_batch_samples:] self.mutex.release() logging.debug( - f'Processing next frame, samples = {batch.size}, total samples = {self.queue.size}, amplitude = {self.amplitude(batch)}') + f'Processing next frame, samples = {samples.size}, total samples = {self.queue.size}, amplitude = {self.amplitude(samples)}') self.status_callback( Status(State.STARTING_NEXT_TRANSCRIPTION)) time_started = datetime.datetime.now() - result = self.model.transcribe( - audio=batch, language=self.language, task=self.task.value, - initial_prompt=self.text) # prompt model with text from previous transcriptions - batch_text: str = result.get('text') + if isinstance(self.model, whisper.Whisper): + result = self.model.transcribe( + audio=samples, language=self.language, task=self.task.value, + initial_prompt=self.text) # prompt model with text from previous transcriptions + else: + result = self.model.transcribe( + audio=samples, + params=_whisper.whisper_cpp_params( + language=self.language if self.language is not None else 'en', + task=self.task.value)) + + next_text: str = result.get('text') logging.debug( - f'Received next result, length = {len(batch_text)}, time taken = {datetime.datetime.now() - time_started}') + f'Received next result, length = {len(next_text)}, time taken = {datetime.datetime.now() - time_started}') self.status_callback( - Status(State.FINISHED_CURRENT_TRANSCRIPTION, batch_text)) + Status(State.FINISHED_CURRENT_TRANSCRIPTION, next_text)) - self.text += f'\n\n{batch_text}' + self.text += f'\n\n{next_text}' else: self.mutex.release() @@ -137,18 +146,112 @@ class RecordingTranscriber: logging.debug('Processing thread terminated') +def more_data(fd: int): + r, _, _ = select.select([fd], [], [], 0) + return bool(r) + + +def read_pipe_str(fd: int): + out = b'' + while more_data(fd): + out += os.read(fd, 1024) + return out.decode('utf-8') + + +@contextmanager +def capture_fd(fd: int): + """Captures and restores a file descriptor into a pipe + + Args: + fd (int): file descriptor + + Yields: + Tuple[int, int]: previous descriptor and pipe output + """ + pipe_out, pipe_in = os.pipe() + prev = os.dup(fd) + os.dup2(pipe_in, fd) + try: + yield (prev, pipe_out) + finally: + os.dup2(prev, fd) + + +class OutputFormat(enum.Enum): + TXT = 'txt' + VTT = 'vtt' + SRT = 'srt' + + +def to_timestamp(ms: float) -> str: + hr = int(ms / (1000*60*60)) + ms = ms - hr * (1000*60*60) + min = int(ms / (1000*60)) + ms = ms - min * (1000*60) + sec = int(ms / 1000) + ms = int(ms - sec * 1000) + return f'{hr:02d}:{min:02d}:{sec:02d}.{ms:03d}' + + +def write_output(path: str, segments: List[Segment], should_open: bool, output_format: OutputFormat): + file = open(path, 'w') + + if output_format == OutputFormat.TXT: + for segment in segments: + file.write(segment.text) + + elif output_format == OutputFormat.VTT: + file.write('WEBVTT\n\n') + for segment in segments: + file.write( + f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n') + file.write(f'{segment.text}\n\n') + + elif output_format == OutputFormat.SRT: + for (i, segment) in enumerate(segments): + file.write(f'{i+1}\n') + file.write( + f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n') + file.write(f'{segment.text}\n\n') + + file.close() + + if should_open: + try: + os.startfile(path) + except AttributeError: + opener = "open" if platform.system() == "Darwin" else "xdg-open" + subprocess.call([opener, path]) + + +def transcribe_cpp( + model: _whisper.WhisperCpp, audio: Union[np.ndarray, str], + params: Any, output_file_path: str, open_file_on_complete: bool, output_format): + result = model.transcribe(audio=audio, params=params) + write_output(output_file_path, result.get( + 'segments'), open_file_on_complete, output_format) + + class FileTranscriber: """FileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file using the default program for opening txt files.""" stopped = False - def __init__(self, model: whisper.Whisper, language: Optional[str], task: Task, file_path: str, output_file_path: str, progress_callback: Callable[[int, int], None]) -> None: + def __init__( + self, model: Union[whisper.Whisper, _whisper.WhisperCpp], language: Optional[str], + task: _whisper.Task, file_path: str, output_file_path: str, + progress_callback: Callable[[int, int], None] = lambda *_: None, + open_file_on_complete=True) -> None: self.model = model self.file_path = file_path self.output_file_path = output_file_path self.progress_callback = progress_callback self.language = language self.task = task + self.open_file_on_complete = open_file_on_complete + + _, extension = os.path.splitext(self.output_file_path) + self.output_format = OutputFormat(extension[1:]) def start(self): self.current_thread = Thread(target=self.transcribe) @@ -156,23 +259,55 @@ class FileTranscriber: def transcribe(self): try: - result = _whisper.transcribe( - model=self.model, audio=self.file_path, - progress_callback=self.progress_callback, - language=self.language, task=self.task.value, - check_stopped=self.check_stopped) + if isinstance(self.model, _whisper. WhisperCpp): + self.progress_callback(0, 100) + + with capture_fd(2) as (_, stderr): + process = multiprocessing.Process( + target=transcribe_cpp, + args=( + self.model, self.file_path, + _whisper.whisper_cpp_params( + language=self.language if self.language is not None else 'en', + task=self.task, print_realtime=True, print_progress=True), + self.output_file_path, + self.open_file_on_complete, + self.output_format)) + process.start() + + while process.is_alive(): + if self.check_stopped(): + process.kill() + + next_stderr = read_pipe_str(stderr) + if len(next_stderr) > 0: + progress = _whisper.whisper_cpp_progress( + next_stderr) + if progress != None: + self.progress_callback(progress, 100) + + self.progress_callback(100, 100) + else: + result = _whisper.transcribe( + model=self.model, audio=self.file_path, + progress_callback=self.progress_callback, + language=self.language, task=self.task.value, + check_stopped=self.check_stopped) + + segments = map( + lambda segment: Segment( + start=segment.get('start')*1000, # s to ms + end=segment.get('end')*1000, # s to ms + text=segment.get('text')), + result.get('segments')) + + write_output(self.output_file_path, list( + segments), self.open_file_on_complete, self.output_format) except _whisper.Stopped: return - output_file = open(self.output_file_path, 'w') - output_file.write(result.get('text')) - output_file.close() - - try: - os.startfile(self.output_file_path) - except AttributeError: - opener = "open" if platform.system() == "Darwin" else "xdg-open" - subprocess.call([opener, self.output_file_path]) + def join(self): + self.current_thread.join() def stop(self): self.stopped = True @@ -181,5 +316,5 @@ class FileTranscriber: return self.stopped @classmethod - def get_default_output_file_path(cls, task: Task, input_file_path: str): + def get_default_output_file_path(cls, task: _whisper.Task, input_file_path: str): return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).txt' diff --git a/transcriber_test.py b/transcriber_test.py index df4bc0e4..a9d68283 100644 --- a/transcriber_test.py +++ b/transcriber_test.py @@ -1,5 +1,12 @@ -from transcriber import FileTranscriber, RecordingTranscriber, Status, Task +import os +import tempfile + +import pytest + +from _whisper import Task, WhisperCpp +from transcriber import (FileTranscriber, RecordingTranscriber, Status, + to_timestamp) class TestRecordingTranscriber: @@ -16,3 +23,22 @@ class TestFileTranscriber: def test_default_output_file(self): assert FileTranscriber.get_default_output_file_path( Task.TRANSLATE, '/a/b/c.txt').startswith('/a/b/c (Translated on ') + + @pytest.mark.skip(reason='test ggml model not working for') + def test_transcribe_whisper_cpp(self): + output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt') + transcriber = FileTranscriber( + model=WhisperCpp('testdata/ggml-tiny.bin'), language='en', + task=Task.TRANSCRIBE, file_path='testdata/whisper.m4a', + output_file_path=output_file_path, + open_file_on_complete=False) + transcriber.start() + transcriber.join() + + assert os.path.isfile(output_file_path) + + +class TestToTimestamp: + def test_to_timestamp(self): + assert to_timestamp(0) == '00:00:00.000' + assert to_timestamp(123456789) == '34:17:36.789' diff --git a/whisper.cpp b/whisper.cpp new file mode 160000 index 00000000..5698b517 --- /dev/null +++ b/whisper.cpp @@ -0,0 +1 @@ +Subproject commit 5698b51718c8588034a98c4c2651979af34f0e2e