Refactor transcribers class (#508)

This commit is contained in:
Chidi Williams 2023-06-26 08:46:17 +01:00 committed by GitHub
commit 2dc0797e64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 143 additions and 131 deletions

View file

@ -35,8 +35,9 @@ from .store.keyring_store import KeyringStore
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
Task,
TranscriptionOptions,
FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL,
FileTranscriptionTask, LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE, LANGUAGES)
from .recording_transcriber import RecordingTranscriber
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
from .widgets.line_edit import LineEdit
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog

View file

@ -0,0 +1,139 @@
import datetime
import logging
import threading
from typing import Optional
import numpy as np
import sounddevice
import whisper
from PyQt6.QtCore import QObject, pyqtSignal
from sounddevice import PortAudioError
from buzz import transformers_whisper
from buzz.model_loader import ModelType
from buzz.transcriber import TranscriptionOptions, WhisperCpp, whisper_cpp_params
from buzz.transformers_whisper import TransformersWhisper
class RecordingTranscriber(QObject):
transcription = pyqtSignal(str)
finished = pyqtSignal()
error = pyqtSignal(str)
is_running = False
MAX_QUEUE_SIZE = 10
def __init__(self, transcription_options: TranscriptionOptions,
input_device_index: Optional[int], sample_rate: int, model_path: str,
parent: Optional[QObject] = None) -> None:
super().__init__(parent)
self.transcription_options = transcription_options
self.current_stream = None
self.input_device_index = input_device_index
self.sample_rate = sample_rate
self.model_path = model_path
self.n_batch_samples = 5 * self.sample_rate # every 5 seconds
# pause queueing if more than 3 batches behind
self.max_queue_size = 3 * self.n_batch_samples
self.queue = np.ndarray([], dtype=np.float32)
self.mutex = threading.Lock()
def start(self):
model_path = self.model_path
if self.transcription_options.model.model_type == ModelType.WHISPER:
model = whisper.load_model(model_path)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
model = WhisperCpp(model_path)
else: # ModelType.HUGGING_FACE
model = transformers_whisper.load_model(model_path)
initial_prompt = self.transcription_options.initial_prompt
logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s',
self.transcription_options, model_path, self.sample_rate, self.input_device_index)
self.is_running = True
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
self.mutex.release()
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()
if self.transcription_options.model.model_type == ModelType.WHISPER:
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
next_text: str = result.get('text')
# Update initial prompt between successive recording chunks
initial_prompt += next_text
logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
self.mutex.release()
except PortAudioError as exc:
self.error.emit(str(exc))
logging.exception('')
return
self.finished.emit()
@staticmethod
def get_device_sample_rate(device_id: Optional[int]) -> int:
"""Returns the sample rate to be used for recording. It uses the default sample rate
provided by Whisper if the microphone supports it, or else it uses the device's default
sample rate.
"""
whisper_sample_rate = whisper.audio.SAMPLE_RATE
try:
sounddevice.check_input_settings(
device=device_id, samplerate=whisper_sample_rate)
return whisper_sample_rate
except PortAudioError:
device_info = sounddevice.query_devices(device=device_id)
if isinstance(device_info, dict):
return int(device_info.get('default_samplerate', whisper_sample_rate))
return whisper_sample_rate
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
# Try to enqueue the next block. If the queue is already full, drop the block.
chunk: np.ndarray = in_data.ravel()
with self.mutex:
if self.queue.size < self.max_queue_size:
self.queue = np.append(self.queue, chunk)
@staticmethod
def amplitude(arr: np.ndarray):
return (abs(max(arr)) + abs(min(arr))) / 2
def stop_recording(self):
self.is_running = False

View file

@ -7,7 +7,6 @@ import multiprocessing
import os
import sys
import tempfile
import threading
from abc import abstractmethod
from dataclasses import dataclass, field
from multiprocessing.connection import Connection
@ -19,18 +18,15 @@ import faster_whisper
import ffmpeg
import numpy as np
import openai
import sounddevice
import stable_whisper
import tqdm
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from sounddevice import PortAudioError
from whisper import tokenizer
from . import transformers_whisper
from .conn import pipe_stderr
from .model_loader import TranscriptionModel, ModelType
from .transformers_whisper import TransformersWhisper
# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
@ -101,130 +97,6 @@ class FileTranscriptionTask:
completed_at: Optional[datetime.datetime] = None
class RecordingTranscriber(QObject):
transcription = pyqtSignal(str)
finished = pyqtSignal()
error = pyqtSignal(str)
is_running = False
MAX_QUEUE_SIZE = 10
def __init__(self, transcription_options: TranscriptionOptions,
input_device_index: Optional[int], sample_rate: int, model_path: str,
parent: Optional[QObject] = None) -> None:
super().__init__(parent)
self.transcription_options = transcription_options
self.current_stream = None
self.input_device_index = input_device_index
self.sample_rate = sample_rate
self.model_path = model_path
self.n_batch_samples = 5 * self.sample_rate # every 5 seconds
# pause queueing if more than 3 batches behind
self.max_queue_size = 3 * self.n_batch_samples
self.queue = np.ndarray([], dtype=np.float32)
self.mutex = threading.Lock()
def start(self):
model_path = self.model_path
if self.transcription_options.model.model_type == ModelType.WHISPER:
model = whisper.load_model(model_path)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
model = WhisperCpp(model_path)
else: # ModelType.HUGGING_FACE
model = transformers_whisper.load_model(model_path)
initial_prompt = self.transcription_options.initial_prompt
logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s',
self.transcription_options, model_path, self.sample_rate, self.input_device_index)
self.is_running = True
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
self.mutex.release()
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()
if self.transcription_options.model.model_type == ModelType.WHISPER:
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
next_text: str = result.get('text')
# Update initial prompt between successive recording chunks
initial_prompt += next_text
logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
self.mutex.release()
except PortAudioError as exc:
self.error.emit(str(exc))
logging.exception('')
return
self.finished.emit()
@staticmethod
def get_device_sample_rate(device_id: Optional[int]) -> int:
"""Returns the sample rate to be used for recording. It uses the default sample rate
provided by Whisper if the microphone supports it, or else it uses the device's default
sample rate.
"""
whisper_sample_rate = whisper.audio.SAMPLE_RATE
try:
sounddevice.check_input_settings(
device=device_id, samplerate=whisper_sample_rate)
return whisper_sample_rate
except PortAudioError:
device_info = sounddevice.query_devices(device=device_id)
if isinstance(device_info, dict):
return int(device_info.get('default_samplerate', whisper_sample_rate))
return whisper_sample_rate
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
# Try to enqueue the next block. If the queue is already full, drop the block.
chunk: np.ndarray = in_data.ravel()
with self.mutex:
if self.queue.size < self.max_queue_size:
self.queue = np.append(self.queue, chunk)
@staticmethod
def amplitude(arr: np.ndarray):
return (abs(max(arr)) + abs(min(arr))) / 2
def stop_recording(self):
self.is_running = False
class OutputFormat(enum.Enum):
TXT = 'txt'
SRT = 'srt'

View file

@ -11,11 +11,11 @@ from PyQt6.QtCore import QThread
from pytestqt.qtbot import QtBot
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber,
get_default_output_file_path, to_timestamp,
whisper_cpp_params, write_output, TranscriptionOptions)
from buzz.recording_transcriber import RecordingTranscriber
from tests.mock_sounddevice import MockInputStream
from tests.model_loader import get_model_path