Fix recording transcriber (#286)

This commit is contained in:
Chidi Williams 2023-01-02 15:42:15 +00:00 committed by GitHub
parent 611e623a3a
commit 380e975870
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 57 deletions

View file

@ -157,11 +157,11 @@ class RecordButton(QPushButton):
self.setDefault(True)
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
def set_to_record(self):
# Put the button into its "not recording" appearance: label it 'Record' and
# mark it as the default button (QPushButton.setDefault(True)).
def set_stopped(self):
self.setText('Record')
self.setDefault(True)
def set_to_stop(self):
# Put the button into its "recording" appearance: label it 'Stop' and clear
# the default-button flag (QPushButton.setDefault(False)).
def set_recording(self):
self.setText('Stop')
self.setDefault(False)
@ -529,11 +529,10 @@ class RecordingTranscriberWidget(QDialog):
if self.current_status == self.RecordingStatus.STOPPED:
self.start_recording()
self.current_status = self.RecordingStatus.RECORDING
self.record_button.set_to_stop()
self.record_button.set_recording()
else: # RecordingStatus.RECORDING
self.stop_recording()
self.record_button.set_to_record()
self.current_status = self.RecordingStatus.STOPPED
self.set_recording_status_stopped()
def start_recording(self):
self.record_button.setDisabled(True)
@ -567,6 +566,10 @@ class RecordingTranscriberWidget(QDialog):
self.transcriber.finished.connect(self.transcription_thread.quit)
self.transcriber.finished.connect(self.transcriber.deleteLater)
self.transcriber.error.connect(self.on_transcriber_error)
self.transcriber.error.connect(self.transcription_thread.quit)
self.transcriber.error.connect(self.transcriber.deleteLater)
self.transcription_thread.start()
def on_download_model_progress(self, progress: Tuple[float, float]):
@ -580,11 +583,15 @@ class RecordingTranscriberWidget(QDialog):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
# Move the widget to the stopped state: call set_stopped() on the record
# button and store RecordingStatus.STOPPED as the current status, so the
# button appearance and the status flag are updated together.
def set_recording_status_stopped(self):
self.record_button.set_stopped()
self.current_status = self.RecordingStatus.STOPPED
# Handle a model download failure.
#
# Calls reset_model_download(), reports `error` through
# show_model_download_error_dialog(), stops the recording, returns the widget
# to the stopped state and re-enables the record button.
#
# error: text passed on to the error dialog.
def on_download_model_error(self, error: str):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.stop_recording()
# NOTE(review): this listing is a rendered diff that shows removed and added
# lines together. The set_to_stop() call below looks like the pre-change line
# that set_recording_status_stopped() replaces — confirm against the real file.
self.record_button.set_to_stop()
self.set_recording_status_stopped()
self.record_button.setDisabled(False)
def on_next_transcription(self, text: str):
@ -603,13 +610,18 @@ class RecordingTranscriberWidget(QDialog):
self.record_button.setDisabled(True)
# Re-enable the record button after the transcriber is done.
# Presumably connected to the transcriber's `finished` signal; that connection
# is not visible in this listing — verify.
# NOTE(review): this listing is a rendered diff that shows removed and added
# lines together. The direct setEnabled(True) call looks like the pre-change
# line that reset_record_button() replaces — confirm against the real file.
def on_transcriber_finished(self):
self.record_button.setEnabled(True)
self.reset_record_button()
# Slot for the transcriber's `error` signal.
#
# Re-enables the record button, returns the widget to the stopped state and
# shows a critical message box that embeds the error text.
#
# error: error description emitted by the transcriber.
def on_transcriber_error(self, error: str):
self.reset_record_button()
self.set_recording_status_stopped()
QMessageBox.critical(self, '', f'An error occurred while starting a new recording: {error}. Please check your audio devices or check the application logs for more information.')
# Handle cancellation of the model download progress dialog.
#
# Stops the model loader when one exists, calls reset_model_download(),
# returns the widget to the stopped state and re-enables the record button.
# Indentation is lost in this listing, so which statements sit inside the
# `if` cannot be read from here.
def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.stop()
self.reset_model_download()
# NOTE(review): this listing is a rendered diff that shows removed and added
# lines together. The set_to_stop() call below looks like the pre-change line
# that set_recording_status_stopped() replaces — confirm against the real file.
self.record_button.set_to_stop()
self.set_recording_status_stopped()
self.record_button.setDisabled(False)
def reset_model_download(self):
@ -620,11 +632,14 @@ class RecordingTranscriberWidget(QDialog):
# Restore the recording controls: clear the text box placeholder, re-enable
# the record button, and close and forget the model download progress dialog
# if one is open.
def reset_recording_controls(self):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
# NOTE(review): this listing is a rendered diff that shows removed and added
# lines together. The direct setDisabled(False) call looks like the pre-change
# line that reset_record_button() replaces — confirm against the real file.
self.record_button.setDisabled(False)
self.reset_record_button()
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
# Drop the reference so later `is not None` checks see no open dialog.
self.model_download_progress_dialog = None
# Make the record button clickable again by enabling it.
def reset_record_button(self):
self.record_button.setEnabled(True)
# Forward a new amplitude reading to the audio meter widget.
#
# amplitude: value handed unchanged to audio_meter_widget.update_amplitude().
def on_recording_amplitude_changed(self, amplitude: float):
self.audio_meter_widget.update_amplitude(amplitude)

View file

@ -1,5 +1,6 @@
from typing import Optional
import logging
import numpy as np
import sounddevice
from PyQt6.QtCore import QObject, pyqtSignal
@ -16,13 +17,17 @@ class RecordingAmplitudeListener(QObject):
self.input_device_index = input_device_index
# Open and start a sounddevice input stream on self.input_device_index
# (dtype 'float32', one channel) with self.stream_callback as its callback.
# NOTE(review): this listing is a rendered diff that shows removed and added
# lines together. The first, unguarded InputStream(...)/start() lines look
# like the pre-change version of the try/except that follows — confirm
# against the real file.
def start_recording(self):
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
channels=1, callback=self.stream_callback)
self.stream.start()
try:
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
channels=1, callback=self.stream_callback)
self.stream.start()
# A PortAudioError is logged (logging.exception includes the traceback) and
# is not re-raised.
# NOTE(review): the log message is an empty string; a descriptive message
# would make the entry easier to find in the logs.
except sounddevice.PortAudioError:
logging.exception('')
# Stop and close the input stream.
# NOTE(review): this listing is a rendered diff that shows removed and added
# lines together. The unguarded stop()/close() pair looks like the pre-change
# version of the guarded pair below — confirm against the real file.
def stop_recording(self):
self.stream.stop()
self.stream.close()
# NOTE(review): this guard presumably relies on self.stream being initialised
# to None elsewhere (not visible here) — verify.
if self.stream is not None:
self.stream.stop()
self.stream.close()
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
chunk = in_data.ravel()

View file

@ -92,6 +92,7 @@ class FileTranscriptionTask:
class RecordingTranscriber(QObject):
transcription = pyqtSignal(str)
finished = pyqtSignal()
error = pyqtSignal(str)
is_running = False
MAX_QUEUE_SIZE = 10
@ -123,52 +124,57 @@ class RecordingTranscriber(QObject):
self.transcription_options, model_path, self.sample_rate, self.input_device_index)
self.is_running = True
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
self.mutex.release()
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
self.mutex.release()
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()
if self.transcription_options.model.model_type == ModelType.WHISPER:
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
if self.transcription_options.model.model_type == ModelType.WHISPER:
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
next_text: str = result.get('text')
# Update initial prompt between successive recording chunks
initial_prompt += next_text
logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
next_text: str = result.get('text')
# Update initial prompt between successive recording chunks
initial_prompt += next_text
logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
self.mutex.release()
self.mutex.release()
except PortAudioError as exc:
self.error.emit(str(exc))
logging.exception('')
return
self.finished.emit()