diff --git a/.coveragerc b/.coveragerc index 29b81b1..6566e22 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,4 +7,4 @@ omit = directory = coverage/html [report] -fail_under = 73 +fail_under = 72 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6fb98e7..53b6a49 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - uses: actions/setup-python@v4 with: python-version: '3.10.7' @@ -49,6 +50,7 @@ jobs: path: | ~/Library/Caches/Buzz ~/.cache/whisper + ~/.cache/huggingface ~/AppData/Local/Buzz/Buzz/Cache key: whisper-models-${{ runner.os }} @@ -68,7 +70,6 @@ jobs: with: flags: ${{ runner.os }} - build: runs-on: ${{ matrix.os }} strategy: @@ -169,7 +170,7 @@ jobs: include: - os: macos-latest - os: windows-latest - needs: [ build, test ] + needs: [build, test] if: startsWith(github.ref, 'refs/tags/') steps: - uses: actions/checkout@v3 diff --git a/buzz/gui.py b/buzz/gui.py index 7fde263..85bad11 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -1,4 +1,3 @@ -import dataclasses import enum import json import logging @@ -7,12 +6,13 @@ import platform import random import sys from datetime import datetime +from enum import auto from typing import Dict, List, Optional, Tuple, Union import humanize import sounddevice from PyQt6 import QtGui -from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread, +from PyQt6.QtCore import (QObject, Qt, QThread, QTimer, QUrl, pyqtSignal, QModelIndex, QSize, QPoint, QUrlQuery, QMetaObject, QEvent) from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon, @@ -23,13 +23,14 @@ from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog, QMainWindow, QMessageBox, QPlainTextEdit, QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu, QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem, - QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QToolButton) + QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QToolButton, QSizePolicy) from requests import get from whisper import tokenizer from buzz.cache import TasksCache from .__version__ import VERSION from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel +from .recording import RecordingAmplitudeListener from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat, Task, WhisperCppFileTranscriber, WhisperFileTranscriber, @@ -151,40 +152,18 @@ class TextDisplayBox(QPlainTextEdit): class RecordButton(QPushButton): - class Status(enum.Enum): - RECORDING = enum.auto() - STOPPED = enum.auto() + def __init__(self, parent: Optional[QWidget]) -> None: + super().__init__("Record", parent) + self.setDefault(True) + self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)) - current_status = Status.STOPPED - status_changed = pyqtSignal(Status) - - def __init__(self, parent: Optional[QWidget], *args) -> None: - super().__init__("Record", parent, *args) - self.clicked.connect(self.on_click_record) - self.status_changed.connect(self.on_status_changed) + def set_to_record(self): + self.setText('Record') self.setDefault(True) - def on_click_record(self): - current_status: RecordButton.Status - if self.current_status == self.Status.RECORDING: - current_status = self.Status.STOPPED - else: - current_status = self.Status.RECORDING - - self.status_changed.emit(current_status) - - # TODO: control the text and status from the caller - def on_status_changed(self, status: Status): - self.current_status = status - if status == self.Status.RECORDING: - self.setText('Stop') - self.setDefault(False) - else: - self.setText('Record') - self.setDefault(True) - - def force_stop(self): - self.on_status_changed(self.Status.STOPPED) + def set_to_stop(self): + self.setText('Stop') + self.setDefault(False) class DownloadModelProgressDialog(QProgressDialog): @@ -208,37 +187,6 @@ class DownloadModelProgressDialog(QProgressDialog): f'Downloading model ({fraction_completed :.0%}, {humanize.naturaldelta(time_left)} remaining)') -class TimerLabel(QLabel): - start_time: Optional[QDateTime] - - def __init__(self, parent: Optional[QWidget]): - super().__init__(parent) - - self.timer = QTimer(self) - self.timer.timeout.connect(self.on_next_interval) - self.on_next_interval(stopped=True) - self.setAlignment(Qt.AlignmentFlag( - Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight)) - - def start_timer(self): - self.timer.start(1000) - self.start_time = QDateTime.currentDateTimeUtc() - self.on_next_interval() - - def stop_timer(self): - self.timer.stop() - self.on_next_interval(stopped=True) - - def on_next_interval(self, stopped=False): - if stopped: - self.setText('--:--') - elif self.start_time != None: - seconds_passed = self.start_time.secsTo( - QDateTime.currentDateTimeUtc()) - self.setText('{0:02}:{1:02}'.format( - seconds_passed // 60, seconds_passed % 60)) - - def show_model_download_error_dialog(parent: QWidget, error: str): message = f"An error occurred while loading the Whisper model: {error}{'' if error.endswith('.') else '.'}" \ f'information. ' @@ -432,20 +380,84 @@ class AdvancedSettingsButton(QPushButton): super().__init__('Advanced...', parent) -class RecordingTranscriberWidget(QWidget): - current_status = RecordButton.Status.STOPPED +class AudioMeterWidget(QWidget): + current_amplitude: float + BAR_WIDTH = 2 + BAR_MARGIN = 1 + BAR_INACTIVE_COLOR: QColor + BAR_ACTIVE_COLOR: QColor + + # Factor by which the amplitude is scaled to make the changes more visible + DIFF_MULTIPLIER_FACTOR = 10 + SMOOTHING_FACTOR = 0.95 + + def __init__(self, parent: Optional[QWidget] = None): + super().__init__(parent) + self.setMinimumWidth(10) + self.setFixedHeight(16) + + # Extra padding to fix layout + self.PADDING_TOP = 3 + + self.current_amplitude = 0.0 + + self.MINIMUM_AMPLITUDE = 0.00005 # minimum amplitude to show the first bar + self.AMPLITUDE_SCALE_FACTOR = 15 # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars + + if self.palette().window().color().black() > 127: + self.BAR_INACTIVE_COLOR = QColor('#555') + self.BAR_ACTIVE_COLOR = QColor('#999') + else: + self.BAR_INACTIVE_COLOR = QColor('#BBB') + self.BAR_ACTIVE_COLOR = QColor('#555') + + def paintEvent(self, event: QtGui.QPaintEvent) -> None: + painter = QPainter(self) + painter.setPen(Qt.PenStyle.NoPen) + + rect = self.rect() + center_x = rect.center().x() + num_bars_in_half = int((rect.width() / 2) / (self.BAR_MARGIN + self.BAR_WIDTH)) + for i in range(num_bars_in_half): + is_bar_active = ((self.current_amplitude - self.MINIMUM_AMPLITUDE) * self.AMPLITUDE_SCALE_FACTOR) > ( + i / num_bars_in_half) + painter.setBrush(self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR) + + # draw to left + painter.drawRect(center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)), rect.top() + self.PADDING_TOP, + self.BAR_WIDTH, + rect.height() - self.PADDING_TOP) + # draw to right + painter.drawRect(center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))), + rect.top() + self.PADDING_TOP, + self.BAR_WIDTH, rect.height() - self.PADDING_TOP) + + def update_amplitude(self, amplitude: float): + self.current_amplitude = max(amplitude, self.current_amplitude * self.SMOOTHING_FACTOR) + self.repaint() + + +class RecordingTranscriberWidget(QDialog): + current_status: 'RecordingStatus' transcription_options: TranscriptionOptions selected_device_id: Optional[int] model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None transcriber: Optional[RecordingTranscriber] = None model_loader: Optional[ModelLoader] = None transcription_thread: Optional[QThread] = None + recording_amplitude_listener: Optional[RecordingAmplitudeListener] = None + device_sample_rate: Optional[int] = None - def __init__(self, parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None: - super().__init__(parent, flags) + class RecordingStatus(enum.Enum): + STOPPED = auto() + RECORDING = auto() + + def __init__(self, parent: Optional[QWidget] = None) -> None: + super().__init__(parent) layout = QVBoxLayout(self) + self.current_status = self.RecordingStatus.STOPPED self.setWindowTitle('Live Recording') self.transcription_options = TranscriptionOptions( @@ -457,10 +469,8 @@ class RecordingTranscriberWidget(QWidget): self.on_device_changed) self.selected_device_id = self.audio_devices_combo_box.get_default_device_id() - self.timer_label = TimerLabel(self) - self.record_button = RecordButton(self) - self.record_button.status_changed.connect(self.on_status_changed) + self.record_button.clicked.connect(self.on_record_button_clicked) self.text_box = TextDisplayBox(self) self.text_box.setPlaceholderText('Click Record to begin...') @@ -474,9 +484,10 @@ class RecordingTranscriberWidget(QWidget): recording_options_layout.addRow( 'Microphone:', self.audio_devices_combo_box) + self.audio_meter_widget = AudioMeterWidget(self) + record_button_layout = QHBoxLayout() - record_button_layout.addStretch() - record_button_layout.addWidget(self.timer_label) + record_button_layout.addWidget(self.audio_meter_widget) record_button_layout.addWidget(self.record_button) layout.addWidget(transcription_options_group_box) @@ -487,21 +498,42 @@ class RecordingTranscriberWidget(QWidget): self.setLayout(layout) self.setFixedSize(self.sizeHint()) - def closeEvent(self, event: QCloseEvent) -> None: - self.stop_recording() - return super().closeEvent(event) + self.reset_recording_amplitude_listener() def on_transcription_options_changed(self, transcription_options: TranscriptionOptions): self.transcription_options = transcription_options def on_device_changed(self, device_id: int): self.selected_device_id = device_id + self.reset_recording_amplitude_listener() - def on_status_changed(self, status: RecordButton.Status): - if status == RecordButton.Status.RECORDING: + def reset_recording_amplitude_listener(self): + if self.recording_amplitude_listener is not None: + self.recording_amplitude_listener.stop_recording() + self.recording_amplitude_listener.deleteLater() + + # Listening to audio will fail if there are no input devices + if self.selected_device_id is None or self.selected_device_id == -1: + return + + # Get the device sample rate before starting the listener as the PortAudio function + # fails if you try to get the device's settings while recording is in progress. + self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(self.selected_device_id) + + self.recording_amplitude_listener = RecordingAmplitudeListener(input_device_index=self.selected_device_id, + parent=self) + self.recording_amplitude_listener.amplitude_changed.connect(self.on_recording_amplitude_changed) + self.recording_amplitude_listener.start_recording() + + def on_record_button_clicked(self): + if self.current_status == self.RecordingStatus.STOPPED: self.start_recording() - else: + self.current_status = self.RecordingStatus.RECORDING + self.record_button.set_to_stop() + else: # RecordingStatus.RECORDING self.stop_recording() + self.record_button.set_to_record() + self.current_status = self.RecordingStatus.STOPPED def start_recording(self): self.record_button.setDisabled(True) @@ -510,6 +542,7 @@ class RecordingTranscriberWidget(QWidget): self.model_loader = ModelLoader(model=self.transcription_options.model) self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id, + sample_rate=self.device_sample_rate, transcription_options=self.transcription_options) self.model_loader.moveToThread(self.transcription_thread) @@ -551,7 +584,7 @@ class RecordingTranscriberWidget(QWidget): self.reset_model_download() show_model_download_error_dialog(self, error) self.stop_recording() - self.record_button.force_stop() + self.record_button.set_to_stop() self.record_button.setDisabled(False) def on_next_transcription(self, text: str): @@ -568,7 +601,6 @@ class RecordingTranscriberWidget(QWidget): self.transcriber.stop_recording() # Disable record button until the transcription is actually stopped in the background self.record_button.setDisabled(True) - self.timer_label.stop_timer() def on_transcriber_finished(self): self.record_button.setEnabled(True) @@ -577,7 +609,7 @@ class RecordingTranscriberWidget(QWidget): if self.model_loader is not None: self.model_loader.stop() self.reset_model_download() - self.record_button.force_stop() + self.record_button.set_to_stop() self.record_button.setDisabled(False) def reset_model_download(self): @@ -588,12 +620,21 @@ class RecordingTranscriberWidget(QWidget): def reset_recording_controls(self): # Clear text box placeholder because the first chunk takes a while to process self.text_box.setPlaceholderText('') - self.timer_label.start_timer() self.record_button.setDisabled(False) if self.model_download_progress_dialog is not None: self.model_download_progress_dialog.close() self.model_download_progress_dialog = None + def on_recording_amplitude_changed(self, amplitude: float): + self.audio_meter_widget.update_amplitude(amplitude) + + def closeEvent(self, event: QCloseEvent) -> None: + self.stop_recording() + if self.recording_amplitude_listener is not None: + self.recording_amplitude_listener.stop_recording() + self.recording_amplitude_listener.deleteLater() + return super().closeEvent(event) + def get_asset_path(path: str): if getattr(sys, 'frozen', False): @@ -800,9 +841,8 @@ class MainWindowToolbar(QToolBar): return QIcon(pixmap) def on_record_action_triggered(self): - recording_transcriber_window = RecordingTranscriberWidget( - self, flags=Qt.WindowType.Window) - recording_transcriber_window.show() + recording_transcriber_window = RecordingTranscriberWidget(self) + recording_transcriber_window.exec() def set_open_transcript_action_disabled(self, disabled: bool): self.open_transcript_action.setDisabled(disabled) diff --git a/buzz/recording.py b/buzz/recording.py new file mode 100644 index 0000000..9673261 --- /dev/null +++ b/buzz/recording.py @@ -0,0 +1,30 @@ +from typing import Optional + +import numpy as np +import sounddevice +from PyQt6.QtCore import QObject, pyqtSignal + + +class RecordingAmplitudeListener(QObject): + stream: Optional[sounddevice.InputStream] = None + amplitude_changed = pyqtSignal(float) + + def __init__(self, input_device_index: Optional[int] = None, + parent: Optional[QObject] = None, + ): + super().__init__(parent) + self.input_device_index = input_device_index + + def start_recording(self): + self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32', + channels=1, callback=self.stream_callback) + self.stream.start() + + def stop_recording(self): + self.stream.stop() + self.stream.close() + + def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status): + chunk = in_data.ravel() + amplitude = np.sqrt(np.mean(chunk ** 2)) # root-mean-square + self.amplitude_changed.emit(amplitude) diff --git a/buzz/transcriber.py b/buzz/transcriber.py index c12af10..563cbf1 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -96,13 +96,12 @@ class RecordingTranscriber(QObject): MAX_QUEUE_SIZE = 10 def __init__(self, transcription_options: TranscriptionOptions, - input_device_index: Optional[int], parent: Optional[QObject] = None) -> None: + input_device_index: Optional[int], sample_rate: int, parent: Optional[QObject] = None) -> None: super().__init__(parent) self.transcription_options = transcription_options self.current_stream = None self.input_device_index = input_device_index - self.sample_rate = self.get_device_sample_rate( - device_id=input_device_index) + self.sample_rate = sample_rate self.n_batch_samples = 5 * self.sample_rate # every 5 seconds # pause queueing if more than 3 batches behind self.max_queue_size = 3 * self.n_batch_samples @@ -125,7 +124,6 @@ class RecordingTranscriber(QObject): self.is_running = True with sounddevice.InputStream(samplerate=self.sample_rate, - blocksize=1 * self.sample_rate, # 1 sec device=self.input_device_index, dtype="float32", channels=1, callback=self.stream_callback): while self.is_running: @@ -288,8 +286,6 @@ class WhisperCppFileTranscriber(QObject): def read_std_out(self): try: output = self.process.readAllStandardOutput().data().decode('UTF-8').strip() - logging.debug('whisper_cpp (output): %s', output) - if len(output) > 0: lines = output.split('\n') for line in lines: