mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-26 11:40:09 +02:00
Add recording amplitude indicator (#282)
This commit is contained in:
parent
5f9045c347
commit
820594ee35
|
@ -7,4 +7,4 @@ omit =
|
|||
directory = coverage/html
|
||||
|
||||
[report]
|
||||
fail_under = 73
|
||||
fail_under = 72
|
||||
|
|
5
.github/workflows/ci.yml
vendored
5
.github/workflows/ci.yml
vendored
|
@ -25,6 +25,7 @@ jobs:
|
|||
- uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10.7'
|
||||
|
@ -49,6 +50,7 @@ jobs:
|
|||
path: |
|
||||
~/Library/Caches/Buzz
|
||||
~/.cache/whisper
|
||||
~/.cache/huggingface
|
||||
~/AppData/Local/Buzz/Buzz/Cache
|
||||
key: whisper-models-${{ runner.os }}
|
||||
|
||||
|
@ -68,7 +70,6 @@ jobs:
|
|||
with:
|
||||
flags: ${{ runner.os }}
|
||||
|
||||
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
|
@ -169,7 +170,7 @@ jobs:
|
|||
include:
|
||||
- os: macos-latest
|
||||
- os: windows-latest
|
||||
needs: [ build, test ]
|
||||
needs: [build, test]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
|
214
buzz/gui.py
214
buzz/gui.py
|
@ -1,4 +1,3 @@
|
|||
import dataclasses
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
|
@ -7,12 +6,13 @@ import platform
|
|||
import random
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from enum import auto
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import humanize
|
||||
import sounddevice
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread,
|
||||
from PyQt6.QtCore import (QObject, Qt, QThread,
|
||||
QTimer, QUrl, pyqtSignal, QModelIndex, QSize, QPoint,
|
||||
QUrlQuery, QMetaObject, QEvent)
|
||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||
|
@ -23,13 +23,14 @@ from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
|||
QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
|
||||
QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem,
|
||||
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QToolButton)
|
||||
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QToolButton, QSizePolicy)
|
||||
from requests import get
|
||||
from whisper import tokenizer
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from .__version__ import VERSION
|
||||
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
|
||||
from .recording import RecordingAmplitudeListener
|
||||
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
|
||||
Task,
|
||||
WhisperCppFileTranscriber, WhisperFileTranscriber,
|
||||
|
@ -151,40 +152,18 @@ class TextDisplayBox(QPlainTextEdit):
|
|||
|
||||
|
||||
class RecordButton(QPushButton):
|
||||
class Status(enum.Enum):
|
||||
RECORDING = enum.auto()
|
||||
STOPPED = enum.auto()
|
||||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__("Record", parent)
|
||||
self.setDefault(True)
|
||||
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
|
||||
|
||||
current_status = Status.STOPPED
|
||||
status_changed = pyqtSignal(Status)
|
||||
|
||||
def __init__(self, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__("Record", parent, *args)
|
||||
self.clicked.connect(self.on_click_record)
|
||||
self.status_changed.connect(self.on_status_changed)
|
||||
def set_to_record(self):
|
||||
self.setText('Record')
|
||||
self.setDefault(True)
|
||||
|
||||
def on_click_record(self):
|
||||
current_status: RecordButton.Status
|
||||
if self.current_status == self.Status.RECORDING:
|
||||
current_status = self.Status.STOPPED
|
||||
else:
|
||||
current_status = self.Status.RECORDING
|
||||
|
||||
self.status_changed.emit(current_status)
|
||||
|
||||
# TODO: control the text and status from the caller
|
||||
def on_status_changed(self, status: Status):
|
||||
self.current_status = status
|
||||
if status == self.Status.RECORDING:
|
||||
self.setText('Stop')
|
||||
self.setDefault(False)
|
||||
else:
|
||||
self.setText('Record')
|
||||
self.setDefault(True)
|
||||
|
||||
def force_stop(self):
|
||||
self.on_status_changed(self.Status.STOPPED)
|
||||
def set_to_stop(self):
|
||||
self.setText('Stop')
|
||||
self.setDefault(False)
|
||||
|
||||
|
||||
class DownloadModelProgressDialog(QProgressDialog):
|
||||
|
@ -208,37 +187,6 @@ class DownloadModelProgressDialog(QProgressDialog):
|
|||
f'Downloading model ({fraction_completed :.0%}, {humanize.naturaldelta(time_left)} remaining)')
|
||||
|
||||
|
||||
class TimerLabel(QLabel):
    """Label that shows the time elapsed since ``start_timer`` as ``MM:SS``.

    While stopped (including right after construction) the label shows the
    placeholder ``--:--``. While running, the text is refreshed once per
    second from a ``QTimer`` owned by the label.
    """
    # UTC timestamp taken by start_timer(); None until the timer is first
    # started, so a refresh before then is a harmless no-op.
    start_time: Optional[QDateTime] = None

    def __init__(self, parent: Optional[QWidget]):
        """Create the label in the stopped state.

        :param parent: optional parent widget.
        """
        super().__init__(parent)

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.on_next_interval)
        self.on_next_interval(stopped=True)
        self.setAlignment(Qt.AlignmentFlag(
            Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight))

    def start_timer(self):
        """Start counting from now and show the initial elapsed time."""
        self.timer.start(1000)
        self.start_time = QDateTime.currentDateTimeUtc()
        self.on_next_interval()

    def stop_timer(self):
        """Stop the periodic refresh and show the ``--:--`` placeholder."""
        self.timer.stop()
        self.on_next_interval(stopped=True)

    def on_next_interval(self, stopped=False):
        """Refresh the label text.

        :param stopped: when true, show the placeholder instead of the
            elapsed time.
        """
        if stopped:
            self.setText('--:--')
            return

        # Identity check: start_time is either None or a QDateTime.
        if self.start_time is not None:
            seconds_passed = self.start_time.secsTo(
                QDateTime.currentDateTimeUtc())
            minutes, seconds = divmod(seconds_passed, 60)
            self.setText(f'{minutes:02}:{seconds:02}')
|
||||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = f"An error occurred while loading the Whisper model: {error}{'' if error.endswith('.') else '.'}" \
|
||||
f'information. '
|
||||
|
@ -432,20 +380,84 @@ class AdvancedSettingsButton(QPushButton):
|
|||
super().__init__('Advanced...', parent)
|
||||
|
||||
|
||||
class RecordingTranscriberWidget(QWidget):
|
||||
current_status = RecordButton.Status.STOPPED
|
||||
class AudioMeterWidget(QWidget):
    """Horizontal bar meter that visualises the current recording amplitude.

    The widget draws thin vertical bars mirrored around its horizontal
    center. The bars closest to the center become active first; bars
    further out become active as the amplitude grows.
    """
    current_amplitude: float
    BAR_WIDTH = 2
    BAR_MARGIN = 1
    BAR_INACTIVE_COLOR: QColor
    BAR_ACTIVE_COLOR: QColor

    # Factor by which the amplitude is scaled to make the changes more visible
    DIFF_MULTIPLIER_FACTOR = 10
    # Per-update decay applied to the previous amplitude, so the meter falls
    # gradually instead of dropping straight to a quieter value
    SMOOTHING_FACTOR = 0.95

    def __init__(self, parent: Optional[QWidget] = None):
        """Create the meter with zero amplitude.

        :param parent: optional parent widget.
        """
        super().__init__(parent)
        self.setMinimumWidth(10)
        self.setFixedHeight(16)

        # Extra padding to fix layout
        self.PADDING_TOP = 3

        self.current_amplitude = 0.0

        self.MINIMUM_AMPLITUDE = 0.00005  # minimum amplitude to show the first bar
        self.AMPLITUDE_SCALE_FACTOR = 15  # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars

        # Choose bar colors based on how dark the window background is
        if self.palette().window().color().black() > 127:
            self.BAR_INACTIVE_COLOR = QColor('#555')
            self.BAR_ACTIVE_COLOR = QColor('#999')
        else:
            self.BAR_INACTIVE_COLOR = QColor('#BBB')
            self.BAR_ACTIVE_COLOR = QColor('#555')

    def paintEvent(self, event: QtGui.QPaintEvent) -> None:
        """Draw the mirrored bars, coloring each one active or inactive."""
        painter = QPainter(self)
        painter.setPen(Qt.PenStyle.NoPen)

        rect = self.rect()
        center_x = rect.center().x()

        # Geometry and scaled amplitude are the same for every bar, so they
        # are computed once before the loop
        bar_step = self.BAR_MARGIN + self.BAR_WIDTH
        bar_top = rect.top() + self.PADDING_TOP
        bar_height = rect.height() - self.PADDING_TOP
        num_bars_in_half = int((rect.width() / 2) / bar_step)
        scaled_amplitude = (self.current_amplitude - self.MINIMUM_AMPLITUDE) * self.AMPLITUDE_SCALE_FACTOR

        for i in range(num_bars_in_half):
            # Bar i (counted outward from the center) is active once the
            # scaled amplitude exceeds its fractional position in the half
            is_bar_active = scaled_amplitude > (i / num_bars_in_half)
            painter.setBrush(self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR)

            # draw to left
            painter.drawRect(center_x - ((i + 1) * bar_step), bar_top,
                             self.BAR_WIDTH, bar_height)
            # draw to right
            painter.drawRect(center_x + (self.BAR_MARGIN + (i * bar_step)), bar_top,
                             self.BAR_WIDTH, bar_height)

    def update_amplitude(self, amplitude: float):
        """Record a new amplitude reading and schedule a redraw.

        The displayed value is the larger of the new reading and the
        previous value decayed by ``SMOOTHING_FACTOR``.

        :param amplitude: latest amplitude reading.
        """
        self.current_amplitude = max(amplitude, self.current_amplitude * self.SMOOTHING_FACTOR)
        # update() schedules a paint event and lets Qt merge bursts of
        # requests, rather than painting synchronously on every reading
        self.update()
|
||||
|
||||
|
||||
class RecordingTranscriberWidget(QDialog):
|
||||
current_status: 'RecordingStatus'
|
||||
transcription_options: TranscriptionOptions
|
||||
selected_device_id: Optional[int]
|
||||
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
transcriber: Optional[RecordingTranscriber] = None
|
||||
model_loader: Optional[ModelLoader] = None
|
||||
transcription_thread: Optional[QThread] = None
|
||||
recording_amplitude_listener: Optional[RecordingAmplitudeListener] = None
|
||||
device_sample_rate: Optional[int] = None
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
|
||||
super().__init__(parent, flags)
|
||||
class RecordingStatus(enum.Enum):
|
||||
STOPPED = auto()
|
||||
RECORDING = auto()
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
self.current_status = self.RecordingStatus.STOPPED
|
||||
self.setWindowTitle('Live Recording')
|
||||
|
||||
self.transcription_options = TranscriptionOptions(
|
||||
|
@ -457,10 +469,8 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.on_device_changed)
|
||||
self.selected_device_id = self.audio_devices_combo_box.get_default_device_id()
|
||||
|
||||
self.timer_label = TimerLabel(self)
|
||||
|
||||
self.record_button = RecordButton(self)
|
||||
self.record_button.status_changed.connect(self.on_status_changed)
|
||||
self.record_button.clicked.connect(self.on_record_button_clicked)
|
||||
|
||||
self.text_box = TextDisplayBox(self)
|
||||
self.text_box.setPlaceholderText('Click Record to begin...')
|
||||
|
@ -474,9 +484,10 @@ class RecordingTranscriberWidget(QWidget):
|
|||
recording_options_layout.addRow(
|
||||
'Microphone:', self.audio_devices_combo_box)
|
||||
|
||||
self.audio_meter_widget = AudioMeterWidget(self)
|
||||
|
||||
record_button_layout = QHBoxLayout()
|
||||
record_button_layout.addStretch()
|
||||
record_button_layout.addWidget(self.timer_label)
|
||||
record_button_layout.addWidget(self.audio_meter_widget)
|
||||
record_button_layout.addWidget(self.record_button)
|
||||
|
||||
layout.addWidget(transcription_options_group_box)
|
||||
|
@ -487,21 +498,42 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
self.stop_recording()
|
||||
return super().closeEvent(event)
|
||||
self.reset_recording_amplitude_listener()
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
def on_device_changed(self, device_id: int):
|
||||
self.selected_device_id = device_id
|
||||
self.reset_recording_amplitude_listener()
|
||||
|
||||
def on_status_changed(self, status: RecordButton.Status):
|
||||
if status == RecordButton.Status.RECORDING:
|
||||
def reset_recording_amplitude_listener(self):
|
||||
if self.recording_amplitude_listener is not None:
|
||||
self.recording_amplitude_listener.stop_recording()
|
||||
self.recording_amplitude_listener.deleteLater()
|
||||
|
||||
# Listening to audio will fail if there are no input devices
|
||||
if self.selected_device_id is None or self.selected_device_id == -1:
|
||||
return
|
||||
|
||||
# Get the device sample rate before starting the listener as the PortAudio function
|
||||
# fails if you try to get the device's settings while recording is in progress.
|
||||
self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(self.selected_device_id)
|
||||
|
||||
self.recording_amplitude_listener = RecordingAmplitudeListener(input_device_index=self.selected_device_id,
|
||||
parent=self)
|
||||
self.recording_amplitude_listener.amplitude_changed.connect(self.on_recording_amplitude_changed)
|
||||
self.recording_amplitude_listener.start_recording()
|
||||
|
||||
def on_record_button_clicked(self):
|
||||
if self.current_status == self.RecordingStatus.STOPPED:
|
||||
self.start_recording()
|
||||
else:
|
||||
self.current_status = self.RecordingStatus.RECORDING
|
||||
self.record_button.set_to_stop()
|
||||
else: # RecordingStatus.RECORDING
|
||||
self.stop_recording()
|
||||
self.record_button.set_to_record()
|
||||
self.current_status = self.RecordingStatus.STOPPED
|
||||
|
||||
def start_recording(self):
|
||||
self.record_button.setDisabled(True)
|
||||
|
@ -510,6 +542,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
self.model_loader = ModelLoader(model=self.transcription_options.model)
|
||||
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
|
||||
sample_rate=self.device_sample_rate,
|
||||
transcription_options=self.transcription_options)
|
||||
|
||||
self.model_loader.moveToThread(self.transcription_thread)
|
||||
|
@ -551,7 +584,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.reset_model_download()
|
||||
show_model_download_error_dialog(self, error)
|
||||
self.stop_recording()
|
||||
self.record_button.force_stop()
|
||||
self.record_button.set_to_stop()
|
||||
self.record_button.setDisabled(False)
|
||||
|
||||
def on_next_transcription(self, text: str):
|
||||
|
@ -568,7 +601,6 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.transcriber.stop_recording()
|
||||
# Disable record button until the transcription is actually stopped in the background
|
||||
self.record_button.setDisabled(True)
|
||||
self.timer_label.stop_timer()
|
||||
|
||||
def on_transcriber_finished(self):
|
||||
self.record_button.setEnabled(True)
|
||||
|
@ -577,7 +609,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
if self.model_loader is not None:
|
||||
self.model_loader.stop()
|
||||
self.reset_model_download()
|
||||
self.record_button.force_stop()
|
||||
self.record_button.set_to_stop()
|
||||
self.record_button.setDisabled(False)
|
||||
|
||||
def reset_model_download(self):
|
||||
|
@ -588,12 +620,21 @@ class RecordingTranscriberWidget(QWidget):
|
|||
def reset_recording_controls(self):
|
||||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
self.text_box.setPlaceholderText('')
|
||||
self.timer_label.start_timer()
|
||||
self.record_button.setDisabled(False)
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
def on_recording_amplitude_changed(self, amplitude: float):
|
||||
self.audio_meter_widget.update_amplitude(amplitude)
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
self.stop_recording()
|
||||
if self.recording_amplitude_listener is not None:
|
||||
self.recording_amplitude_listener.stop_recording()
|
||||
self.recording_amplitude_listener.deleteLater()
|
||||
return super().closeEvent(event)
|
||||
|
||||
|
||||
def get_asset_path(path: str):
|
||||
if getattr(sys, 'frozen', False):
|
||||
|
@ -800,9 +841,8 @@ class MainWindowToolbar(QToolBar):
|
|||
return QIcon(pixmap)
|
||||
|
||||
def on_record_action_triggered(self):
|
||||
recording_transcriber_window = RecordingTranscriberWidget(
|
||||
self, flags=Qt.WindowType.Window)
|
||||
recording_transcriber_window.show()
|
||||
recording_transcriber_window = RecordingTranscriberWidget(self)
|
||||
recording_transcriber_window.exec()
|
||||
|
||||
def set_open_transcript_action_disabled(self, disabled: bool):
|
||||
self.open_transcript_action.setDisabled(disabled)
|
||||
|
|
30
buzz/recording.py
Normal file
30
buzz/recording.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import sounddevice
|
||||
from PyQt6.QtCore import QObject, pyqtSignal
|
||||
|
||||
|
||||
class RecordingAmplitudeListener(QObject):
    """Listen to an audio input device and report how loud each chunk is.

    ``start_recording`` opens a single-channel ``float32``
    ``sounddevice.InputStream`` on the configured device. For every block of
    samples passed to the stream callback, ``amplitude_changed`` is emitted
    with the block's root-mean-square amplitude.
    """
    # The open input stream, or None while not listening
    stream: Optional[sounddevice.InputStream] = None
    amplitude_changed = pyqtSignal(float)

    def __init__(self, input_device_index: Optional[int] = None,
                 parent: Optional[QObject] = None,
                 ):
        """Create a listener that is not yet recording.

        :param input_device_index: device passed to ``sounddevice.InputStream``.
        :param parent: optional parent ``QObject``.
        """
        super().__init__(parent)
        self.input_device_index = input_device_index

    def start_recording(self):
        """Open the input stream and start receiving audio callbacks."""
        # Release any stream from an earlier start so it is not leaked
        self.stop_recording()

        self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
                                              channels=1, callback=self.stream_callback)
        self.stream.start()

    def stop_recording(self):
        """Stop and close the input stream; a no-op if none is open."""
        stream = self.stream
        if stream is None:
            return

        # Clear the reference first so a repeated call cannot touch a
        # stream that is already closed
        self.stream = None
        try:
            stream.stop()
        finally:
            # Close the stream even if stopping it raised
            stream.close()

    def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
        """Emit the RMS amplitude of one block of samples.

        :param in_data: block of samples delivered by the stream.
        :param frame_count: unused.
        :param time_info: unused.
        :param status: unused.
        """
        chunk = in_data.ravel()
        amplitude = np.sqrt(np.mean(chunk ** 2))  # root-mean-square
        # Convert the NumPy scalar to a built-in float to match the
        # signal's declared argument type
        self.amplitude_changed.emit(float(amplitude))
|
|
@ -96,13 +96,12 @@ class RecordingTranscriber(QObject):
|
|||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions,
|
||||
input_device_index: Optional[int], parent: Optional[QObject] = None) -> None:
|
||||
input_device_index: Optional[int], sample_rate: int, parent: Optional[QObject] = None) -> None:
|
||||
super().__init__(parent)
|
||||
self.transcription_options = transcription_options
|
||||
self.current_stream = None
|
||||
self.input_device_index = input_device_index
|
||||
self.sample_rate = self.get_device_sample_rate(
|
||||
device_id=input_device_index)
|
||||
self.sample_rate = sample_rate
|
||||
self.n_batch_samples = 5 * self.sample_rate # every 5 seconds
|
||||
# pause queueing if more than 3 batches behind
|
||||
self.max_queue_size = 3 * self.n_batch_samples
|
||||
|
@ -125,7 +124,6 @@ class RecordingTranscriber(QObject):
|
|||
|
||||
self.is_running = True
|
||||
with sounddevice.InputStream(samplerate=self.sample_rate,
|
||||
blocksize=1 * self.sample_rate, # 1 sec
|
||||
device=self.input_device_index, dtype="float32",
|
||||
channels=1, callback=self.stream_callback):
|
||||
while self.is_running:
|
||||
|
@ -288,8 +286,6 @@ class WhisperCppFileTranscriber(QObject):
|
|||
def read_std_out(self):
|
||||
try:
|
||||
output = self.process.readAllStandardOutput().data().decode('UTF-8').strip()
|
||||
logging.debug('whisper_cpp (output): %s', output)
|
||||
|
||||
if len(output) > 0:
|
||||
lines = output.split('\n')
|
||||
for line in lines:
|
||||
|
|
Loading…
Reference in a new issue