Add recording amplitude indicator (#282)

This commit is contained in:
Chidi Williams 2023-01-01 23:21:06 +00:00 committed by GitHub
parent 5f9045c347
commit 820594ee35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 163 additions and 96 deletions

View file

@ -7,4 +7,4 @@ omit =
directory = coverage/html
[report]
fail_under = 73
fail_under = 72

View file

@ -25,6 +25,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- uses: actions/setup-python@v4
with:
python-version: '3.10.7'
@ -49,6 +50,7 @@ jobs:
path: |
~/Library/Caches/Buzz
~/.cache/whisper
~/.cache/huggingface
~/AppData/Local/Buzz/Buzz/Cache
key: whisper-models-${{ runner.os }}
@ -68,7 +70,6 @@ jobs:
with:
flags: ${{ runner.os }}
build:
runs-on: ${{ matrix.os }}
strategy:
@ -169,7 +170,7 @@ jobs:
include:
- os: macos-latest
- os: windows-latest
needs: [ build, test ]
needs: [build, test]
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v3

View file

@ -1,4 +1,3 @@
import dataclasses
import enum
import json
import logging
@ -7,12 +6,13 @@ import platform
import random
import sys
from datetime import datetime
from enum import auto
from typing import Dict, List, Optional, Tuple, Union
import humanize
import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread,
from PyQt6.QtCore import (QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QSize, QPoint,
QUrlQuery, QMetaObject, QEvent)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
@ -23,13 +23,14 @@ from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QMainWindow, QMessageBox, QPlainTextEdit,
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem,
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QToolButton)
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QToolButton, QSizePolicy)
from requests import get
from whisper import tokenizer
from buzz.cache import TasksCache
from .__version__ import VERSION
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
from .recording import RecordingAmplitudeListener
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
Task,
WhisperCppFileTranscriber, WhisperFileTranscriber,
@ -151,40 +152,18 @@ class TextDisplayBox(QPlainTextEdit):
class RecordButton(QPushButton):
class Status(enum.Enum):
RECORDING = enum.auto()
STOPPED = enum.auto()
def __init__(self, parent: Optional[QWidget]) -> None:
super().__init__("Record", parent)
self.setDefault(True)
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
current_status = Status.STOPPED
status_changed = pyqtSignal(Status)
def __init__(self, parent: Optional[QWidget], *args) -> None:
super().__init__("Record", parent, *args)
self.clicked.connect(self.on_click_record)
self.status_changed.connect(self.on_status_changed)
def set_to_record(self):
self.setText('Record')
self.setDefault(True)
def on_click_record(self):
current_status: RecordButton.Status
if self.current_status == self.Status.RECORDING:
current_status = self.Status.STOPPED
else:
current_status = self.Status.RECORDING
self.status_changed.emit(current_status)
# TODO: control the text and status from the caller
def on_status_changed(self, status: Status):
self.current_status = status
if status == self.Status.RECORDING:
self.setText('Stop')
self.setDefault(False)
else:
self.setText('Record')
self.setDefault(True)
def force_stop(self):
self.on_status_changed(self.Status.STOPPED)
def set_to_stop(self):
self.setText('Stop')
self.setDefault(False)
class DownloadModelProgressDialog(QProgressDialog):
@ -208,37 +187,6 @@ class DownloadModelProgressDialog(QProgressDialog):
f'Downloading model ({fraction_completed :.0%}, {humanize.naturaldelta(time_left)} remaining)')
class TimerLabel(QLabel):
    """Label that shows the elapsed time as ``MM:SS`` while its timer runs.

    When stopped (including right after construction) the label shows
    ``--:--``.
    """
    # Set by start_timer(); initialised to None so on_next_interval() is safe
    # to call before the timer has ever been started.
    start_time: Optional[QDateTime] = None

    def __init__(self, parent: Optional[QWidget]):
        super().__init__(parent)

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.on_next_interval)
        self.on_next_interval(stopped=True)
        self.setAlignment(Qt.AlignmentFlag(
            Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight))

    def start_timer(self):
        """Record the current UTC time as the start and tick once per second."""
        self.timer.start(1000)
        self.start_time = QDateTime.currentDateTimeUtc()
        self.on_next_interval()

    def stop_timer(self):
        """Stop ticking and reset the label to the stopped placeholder."""
        self.timer.stop()
        self.on_next_interval(stopped=True)

    def on_next_interval(self, stopped=False):
        """Refresh the label text.

        :param stopped: when true, show ``--:--`` instead of the elapsed time.
        """
        if stopped:
            self.setText('--:--')
        elif self.start_time is not None:
            seconds_passed = self.start_time.secsTo(
                QDateTime.currentDateTimeUtc())
            minutes, seconds = divmod(seconds_passed, 60)
            self.setText('{0:02}:{1:02}'.format(minutes, seconds))
def show_model_download_error_dialog(parent: QWidget, error: str):
message = f"An error occurred while loading the Whisper model: {error}{'' if error.endswith('.') else '.'}" \
f'information. '
@ -432,20 +380,84 @@ class AdvancedSettingsButton(QPushButton):
super().__init__('Advanced...', parent)
class RecordingTranscriberWidget(QWidget):
current_status = RecordButton.Status.STOPPED
class AudioMeterWidget(QWidget):
    """Small bar meter that visualises the current input amplitude.

    Bars are drawn in mirrored pairs outwards from the horizontal centre of
    the widget. A bar pair at index ``i`` (0 = closest to the centre) is
    painted in the active colour when the scaled amplitude exceeds
    ``i / number_of_bars_per_half``.
    """
    current_amplitude: float
    BAR_WIDTH = 2
    BAR_MARGIN = 1
    BAR_INACTIVE_COLOR: QColor
    BAR_ACTIVE_COLOR: QColor

    # Factor by which the amplitude is scaled to make the changes more visible
    DIFF_MULTIPLIER_FACTOR = 10
    # Decay applied to the displayed amplitude on each update, so the meter
    # falls off gradually instead of dropping straight to a quieter reading
    SMOOTHING_FACTOR = 0.95

    def __init__(self, parent: Optional[QWidget] = None):
        super().__init__(parent)
        self.setMinimumWidth(10)
        self.setFixedHeight(16)

        # Extra padding to fix layout
        self.PADDING_TOP = 3

        self.current_amplitude = 0.0

        self.MINIMUM_AMPLITUDE = 0.00005  # minimum amplitude to show the first bar
        self.AMPLITUDE_SCALE_FACTOR = 15  # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars

        # Choose bar colours from the window background's black component
        # (presumably > 127 corresponds to a dark theme -- confirm).
        if self.palette().window().color().black() > 127:
            self.BAR_INACTIVE_COLOR = QColor('#555')
            self.BAR_ACTIVE_COLOR = QColor('#999')
        else:
            self.BAR_INACTIVE_COLOR = QColor('#BBB')
            self.BAR_ACTIVE_COLOR = QColor('#555')

    def paintEvent(self, event: QtGui.QPaintEvent) -> None:
        """Draw the mirrored bars, colouring each by the current amplitude."""
        painter = QPainter(self)
        try:
            painter.setPen(Qt.PenStyle.NoPen)

            rect = self.rect()
            center_x = rect.center().x()

            # Geometry and level do not change between bars, so compute once.
            bar_pitch = self.BAR_MARGIN + self.BAR_WIDTH
            num_bars_in_half = int((rect.width() / 2) / bar_pitch)
            bar_top = rect.top() + self.PADDING_TOP
            bar_height = rect.height() - self.PADDING_TOP
            level = (self.current_amplitude - self.MINIMUM_AMPLITUDE) * self.AMPLITUDE_SCALE_FACTOR

            for i in range(num_bars_in_half):
                is_bar_active = level > (i / num_bars_in_half)
                painter.setBrush(self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR)

                # draw to left
                painter.drawRect(center_x - ((i + 1) * bar_pitch), bar_top,
                                 self.BAR_WIDTH, bar_height)
                # draw to right
                painter.drawRect(center_x + (self.BAR_MARGIN + (i * bar_pitch)), bar_top,
                                 self.BAR_WIDTH, bar_height)
        finally:
            # End painting explicitly instead of relying on the QPainter
            # wrapper being garbage-collected when the method returns.
            painter.end()

    def update_amplitude(self, amplitude: float):
        """Store a new amplitude reading and request a redraw.

        The displayed value is the larger of the new reading and the previous
        value multiplied by ``SMOOTHING_FACTOR``.
        """
        self.current_amplitude = max(amplitude, self.current_amplitude * self.SMOOTHING_FACTOR)
        # update() schedules a paint event (Qt may merge several requests)
        # rather than painting synchronously on every reading like repaint().
        self.update()
class RecordingTranscriberWidget(QDialog):
current_status: 'RecordingStatus'
transcription_options: TranscriptionOptions
selected_device_id: Optional[int]
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
transcriber: Optional[RecordingTranscriber] = None
model_loader: Optional[ModelLoader] = None
transcription_thread: Optional[QThread] = None
recording_amplitude_listener: Optional[RecordingAmplitudeListener] = None
device_sample_rate: Optional[int] = None
def __init__(self, parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
super().__init__(parent, flags)
class RecordingStatus(enum.Enum):
STOPPED = auto()
RECORDING = auto()
def __init__(self, parent: Optional[QWidget] = None) -> None:
super().__init__(parent)
layout = QVBoxLayout(self)
self.current_status = self.RecordingStatus.STOPPED
self.setWindowTitle('Live Recording')
self.transcription_options = TranscriptionOptions(
@ -457,10 +469,8 @@ class RecordingTranscriberWidget(QWidget):
self.on_device_changed)
self.selected_device_id = self.audio_devices_combo_box.get_default_device_id()
self.timer_label = TimerLabel(self)
self.record_button = RecordButton(self)
self.record_button.status_changed.connect(self.on_status_changed)
self.record_button.clicked.connect(self.on_record_button_clicked)
self.text_box = TextDisplayBox(self)
self.text_box.setPlaceholderText('Click Record to begin...')
@ -474,9 +484,10 @@ class RecordingTranscriberWidget(QWidget):
recording_options_layout.addRow(
'Microphone:', self.audio_devices_combo_box)
self.audio_meter_widget = AudioMeterWidget(self)
record_button_layout = QHBoxLayout()
record_button_layout.addStretch()
record_button_layout.addWidget(self.timer_label)
record_button_layout.addWidget(self.audio_meter_widget)
record_button_layout.addWidget(self.record_button)
layout.addWidget(transcription_options_group_box)
@ -487,21 +498,42 @@ class RecordingTranscriberWidget(QWidget):
self.setLayout(layout)
self.setFixedSize(self.sizeHint())
def closeEvent(self, event: QCloseEvent) -> None:
self.stop_recording()
return super().closeEvent(event)
self.reset_recording_amplitude_listener()
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
def on_device_changed(self, device_id: int):
self.selected_device_id = device_id
self.reset_recording_amplitude_listener()
def on_status_changed(self, status: RecordButton.Status):
if status == RecordButton.Status.RECORDING:
def reset_recording_amplitude_listener(self):
    """Stop any existing amplitude listener and start one for the selected device.

    When no usable input device is selected (``None`` or ``-1``) the old
    listener is torn down and no new one is created.
    """
    if self.recording_amplitude_listener is not None:
        self.recording_amplitude_listener.stop_recording()
        self.recording_amplitude_listener.deleteLater()
        # Drop the reference so that, if we return early below, later code
        # (e.g. closeEvent) does not call stop_recording() again on a
        # listener that is already stopped and scheduled for deletion.
        self.recording_amplitude_listener = None

    # Listening to audio will fail if there are no input devices
    if self.selected_device_id is None or self.selected_device_id == -1:
        return

    # Get the device sample rate before starting the listener as the PortAudio function
    # fails if you try to get the device's settings while recording is in progress.
    self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(self.selected_device_id)

    self.recording_amplitude_listener = RecordingAmplitudeListener(input_device_index=self.selected_device_id,
                                                                   parent=self)
    self.recording_amplitude_listener.amplitude_changed.connect(self.on_recording_amplitude_changed)
    self.recording_amplitude_listener.start_recording()
def on_record_button_clicked(self):
if self.current_status == self.RecordingStatus.STOPPED:
self.start_recording()
else:
self.current_status = self.RecordingStatus.RECORDING
self.record_button.set_to_stop()
else: # RecordingStatus.RECORDING
self.stop_recording()
self.record_button.set_to_record()
self.current_status = self.RecordingStatus.STOPPED
def start_recording(self):
self.record_button.setDisabled(True)
@ -510,6 +542,7 @@ class RecordingTranscriberWidget(QWidget):
self.model_loader = ModelLoader(model=self.transcription_options.model)
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
sample_rate=self.device_sample_rate,
transcription_options=self.transcription_options)
self.model_loader.moveToThread(self.transcription_thread)
@ -551,7 +584,7 @@ class RecordingTranscriberWidget(QWidget):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.stop_recording()
self.record_button.force_stop()
self.record_button.set_to_stop()
self.record_button.setDisabled(False)
def on_next_transcription(self, text: str):
@ -568,7 +601,6 @@ class RecordingTranscriberWidget(QWidget):
self.transcriber.stop_recording()
# Disable record button until the transcription is actually stopped in the background
self.record_button.setDisabled(True)
self.timer_label.stop_timer()
def on_transcriber_finished(self):
self.record_button.setEnabled(True)
@ -577,7 +609,7 @@ class RecordingTranscriberWidget(QWidget):
if self.model_loader is not None:
self.model_loader.stop()
self.reset_model_download()
self.record_button.force_stop()
self.record_button.set_to_stop()
self.record_button.setDisabled(False)
def reset_model_download(self):
@ -588,12 +620,21 @@ class RecordingTranscriberWidget(QWidget):
def reset_recording_controls(self):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.timer_label.start_timer()
self.record_button.setDisabled(False)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
def on_recording_amplitude_changed(self, amplitude: float):
    """Forward a new amplitude reading to the audio meter widget."""
    self.audio_meter_widget.update_amplitude(amplitude)
def closeEvent(self, event: QCloseEvent) -> None:
    """Stop recording and tear down the amplitude listener before closing."""
    self.stop_recording()

    listener = self.recording_amplitude_listener
    if listener is not None:
        # Stop the listener, then schedule the QObject for deletion.
        listener.stop_recording()
        listener.deleteLater()

    return super().closeEvent(event)
def get_asset_path(path: str):
if getattr(sys, 'frozen', False):
@ -800,9 +841,8 @@ class MainWindowToolbar(QToolBar):
return QIcon(pixmap)
def on_record_action_triggered(self):
recording_transcriber_window = RecordingTranscriberWidget(
self, flags=Qt.WindowType.Window)
recording_transcriber_window.show()
recording_transcriber_window = RecordingTranscriberWidget(self)
recording_transcriber_window.exec()
def set_open_transcript_action_disabled(self, disabled: bool):
self.open_transcript_action.setDisabled(disabled)

30
buzz/recording.py Normal file
View file

@ -0,0 +1,30 @@
from typing import Optional
import numpy as np
import sounddevice
from PyQt6.QtCore import QObject, pyqtSignal
class RecordingAmplitudeListener(QObject):
    """Listens on an input device and emits the amplitude of each captured block.

    ``amplitude_changed`` is emitted from the stream callback with the
    root-mean-square of the samples in the block.
    """
    # The open input stream, or None when not listening.
    stream: Optional[sounddevice.InputStream] = None
    amplitude_changed = pyqtSignal(float)

    def __init__(self, input_device_index: Optional[int] = None,
                 parent: Optional[QObject] = None,
                 ):
        super().__init__(parent)
        self.input_device_index = input_device_index

    def start_recording(self):
        """Open and start a mono float32 input stream on the configured device."""
        # Release a stream left over from an earlier start so it is not leaked.
        self.stop_recording()
        self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
                                              channels=1, callback=self.stream_callback)
        self.stream.start()

    def stop_recording(self):
        """Stop and close the stream, if any.

        Safe to call when recording was never started and safe to call more
        than once.
        """
        stream, self.stream = self.stream, None
        if stream is None:
            return
        try:
            stream.stop()
        finally:
            # Always release the stream, even if stopping it raised.
            stream.close()

    def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
        """Stream callback: emit the RMS amplitude of the captured block."""
        chunk = in_data.ravel()
        amplitude = np.sqrt(np.mean(chunk ** 2))  # root-mean-square
        # Convert the numpy scalar to a plain float to match pyqtSignal(float).
        self.amplitude_changed.emit(float(amplitude))

View file

@ -96,13 +96,12 @@ class RecordingTranscriber(QObject):
MAX_QUEUE_SIZE = 10
def __init__(self, transcription_options: TranscriptionOptions,
input_device_index: Optional[int], parent: Optional[QObject] = None) -> None:
input_device_index: Optional[int], sample_rate: int, parent: Optional[QObject] = None) -> None:
super().__init__(parent)
self.transcription_options = transcription_options
self.current_stream = None
self.input_device_index = input_device_index
self.sample_rate = self.get_device_sample_rate(
device_id=input_device_index)
self.sample_rate = sample_rate
self.n_batch_samples = 5 * self.sample_rate # every 5 seconds
# pause queueing if more than 3 batches behind
self.max_queue_size = 3 * self.n_batch_samples
@ -125,7 +124,6 @@ class RecordingTranscriber(QObject):
self.is_running = True
with sounddevice.InputStream(samplerate=self.sample_rate,
blocksize=1 * self.sample_rate, # 1 sec
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
@ -288,8 +286,6 @@ class WhisperCppFileTranscriber(QObject):
def read_std_out(self):
try:
output = self.process.readAllStandardOutput().data().decode('UTF-8').strip()
logging.debug('whisper_cpp (output): %s', output)
if len(output) > 0:
lines = output.split('\n')
for line in lines: