# Mirror of https://github.com/chidiwilliams/buzz.git
# Synced 2026-03-14 14:45:46 +01:00 (179 lines, 7.2 KiB, Python)
import datetime
|
|
import logging
|
|
import sys
|
|
import threading
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import sounddevice
|
|
from PyQt6.QtCore import QObject, pyqtSignal
|
|
from sounddevice import PortAudioError
|
|
|
|
from buzz import transformers_whisper, whisper_audio
|
|
from buzz.model_loader import ModelType
|
|
from buzz.transcriber.transcriber import TranscriptionOptions
|
|
from buzz.transcriber.whisper_cpp import WhisperCpp, whisper_cpp_params
|
|
from buzz.transformers_whisper import TransformersWhisper
|
|
|
|
if sys.platform != "linux":
|
|
import whisper
|
|
|
|
|
|
class RecordingTranscriber(QObject):
    """Records audio from an input device and transcribes it in ~5-second batches.

    Audio frames arrive on the sounddevice callback thread (``stream_callback``)
    and are appended to ``self.queue`` under ``self.mutex``. ``start()`` drains
    the queue on its own thread, runs the configured Whisper backend on each
    batch, and emits the resulting text through the ``transcription`` signal.
    """

    transcription = pyqtSignal(str)  # emitted with each transcribed text chunk
    finished = pyqtSignal()  # emitted when start() exits normally
    error = pyqtSignal(str)  # emitted with the message of a PortAudioError
    is_running = False
    SAMPLE_RATE = whisper_audio.SAMPLE_RATE
    MAX_QUEUE_SIZE = 10

    def __init__(
        self,
        transcription_options: TranscriptionOptions,
        input_device_index: Optional[int],
        sample_rate: int,
        model_path: str,
        parent: Optional[QObject] = None,
    ) -> None:
        """Store recording/transcription settings.

        :param transcription_options: language, task, model type, prompt, etc.
        :param input_device_index: sounddevice input device, or None for default
        :param sample_rate: recording sample rate in Hz
        :param model_path: filesystem path of the model to load in start()
        :param parent: optional Qt parent object
        """
        super().__init__(parent)
        self.transcription_options = transcription_options
        self.current_stream = None
        self.input_device_index = input_device_index
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.n_batch_samples = 5 * self.sample_rate  # every 5 seconds
        # pause queueing if more than 3 batches behind
        self.max_queue_size = 3 * self.n_batch_samples
        # FIX: np.ndarray([], dtype=...) creates a 0-d array (shape ()) whose
        # single element is *uninitialized* memory; np.append would prepend
        # that garbage sample to the first batch. Start from a genuinely
        # empty 1-d buffer instead.
        self.queue = np.array([], dtype=np.float32)
        # Guards self.queue, which is shared with the audio callback thread.
        self.mutex = threading.Lock()

    def start(self):
        """Open the input stream and transcribe queued batches until stopped.

        Blocks until stop_recording() flips is_running (or the audio backend
        fails), so call it from a worker thread. Emits ``transcription`` per
        batch, then ``finished`` on normal exit or ``error`` on PortAudioError.
        """
        model_path = self.model_path

        # Load the model implementation matching the configured backend.
        if self.transcription_options.model.model_type == ModelType.WHISPER:
            model = whisper.load_model(model_path)
        elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
            model = WhisperCpp(model_path)
        else:  # ModelType.HUGGING_FACE
            model = transformers_whisper.load_model(model_path)

        initial_prompt = self.transcription_options.initial_prompt

        logging.debug(
            "Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s",
            self.transcription_options,
            model_path,
            self.sample_rate,
            self.input_device_index,
        )

        self.is_running = True
        try:
            with sounddevice.InputStream(
                samplerate=self.sample_rate,
                device=self.input_device_index,
                dtype="float32",
                channels=1,
                callback=self.stream_callback,
            ):
                # NOTE(review): this loop busy-waits (acquire/release with no
                # sleep) while the queue has less than one batch — consider a
                # condition variable or short sleep if CPU usage matters.
                while self.is_running:
                    self.mutex.acquire()
                    if self.queue.size >= self.n_batch_samples:
                        # Pop exactly one batch, then release the lock before
                        # the (slow) transcription so the callback can keep
                        # queueing audio.
                        samples = self.queue[: self.n_batch_samples]
                        self.queue = self.queue[self.n_batch_samples :]
                        self.mutex.release()

                        logging.debug(
                            "Processing next frame, sample size = %s, queue size = %s, amplitude = %s",
                            samples.size,
                            self.queue.size,
                            self.amplitude(samples),
                        )
                        time_started = datetime.datetime.now()

                        if (
                            self.transcription_options.model.model_type
                            == ModelType.WHISPER
                        ):
                            assert isinstance(model, whisper.Whisper)
                            result = model.transcribe(
                                audio=samples,
                                language=self.transcription_options.language,
                                task=self.transcription_options.task.value,
                                initial_prompt=initial_prompt,
                                temperature=self.transcription_options.temperature,
                            )
                        elif (
                            self.transcription_options.model.model_type
                            == ModelType.WHISPER_CPP
                        ):
                            assert isinstance(model, WhisperCpp)
                            result = model.transcribe(
                                audio=samples,
                                params=whisper_cpp_params(
                                    language=self.transcription_options.language
                                    if self.transcription_options.language is not None
                                    else "en",
                                    task=self.transcription_options.task.value,
                                    word_level_timings=False,
                                ),
                            )
                        else:
                            assert isinstance(model, TransformersWhisper)
                            result = model.transcribe(
                                audio=samples,
                                language=self.transcription_options.language
                                if self.transcription_options.language is not None
                                else "en",
                                task=self.transcription_options.task.value,
                            )

                        next_text: str = result.get("text")

                        # Update initial prompt between successive recording chunks.
                        # FIX: guard against a None initial prompt, which would
                        # raise TypeError on the first concatenation.
                        initial_prompt = (initial_prompt or "") + next_text

                        logging.debug(
                            "Received next result, length = %s, time taken = %s",
                            len(next_text),
                            datetime.datetime.now() - time_started,
                        )
                        self.transcription.emit(next_text)
                    else:
                        self.mutex.release()
        except PortAudioError as exc:
            self.error.emit(str(exc))
            # FIX: give the traceback log a meaningful message instead of "".
            logging.exception("PortAudioError while recording")
            return

        self.finished.emit()

    @staticmethod
    def get_device_sample_rate(device_id: Optional[int]) -> int:
        """Returns the sample rate to be used for recording. It uses the default sample rate
        provided by Whisper if the microphone supports it, or else it uses the device's default
        sample rate.
        """
        sample_rate = whisper_audio.SAMPLE_RATE
        try:
            sounddevice.check_input_settings(device=device_id, samplerate=sample_rate)
            return sample_rate
        except PortAudioError:
            device_info = sounddevice.query_devices(device=device_id)
            if isinstance(device_info, dict):
                return int(device_info.get("default_samplerate", sample_rate))
            return sample_rate

    def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
        """sounddevice callback: enqueue the next block of audio.

        Runs on the audio thread. If the queue is already full (more than
        three batches behind), the block is dropped rather than queued.
        """
        chunk: np.ndarray = in_data.ravel()
        with self.mutex:
            if self.queue.size < self.max_queue_size:
                self.queue = np.append(self.queue, chunk)

    @staticmethod
    def amplitude(arr: np.ndarray) -> float:
        """Rough amplitude of a (non-empty) sample block, for debug logging only."""
        return (abs(max(arr)) + abs(min(arr))) / 2

    def stop_recording(self):
        """Signal start()'s loop to exit after its current iteration."""
        self.is_running = False
|