mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-29 13:10:26 +02:00
Fix recording transcriber (#286)
This commit is contained in:
parent
611e623a3a
commit
380e975870
33
buzz/gui.py
33
buzz/gui.py
|
@ -157,11 +157,11 @@ class RecordButton(QPushButton):
|
|||
self.setDefault(True)
|
||||
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
|
||||
|
||||
def set_to_record(self):
|
||||
def set_stopped(self):
|
||||
self.setText('Record')
|
||||
self.setDefault(True)
|
||||
|
||||
def set_to_stop(self):
|
||||
def set_recording(self):
|
||||
self.setText('Stop')
|
||||
self.setDefault(False)
|
||||
|
||||
|
@ -529,11 +529,10 @@ class RecordingTranscriberWidget(QDialog):
|
|||
if self.current_status == self.RecordingStatus.STOPPED:
|
||||
self.start_recording()
|
||||
self.current_status = self.RecordingStatus.RECORDING
|
||||
self.record_button.set_to_stop()
|
||||
self.record_button.set_recording()
|
||||
else: # RecordingStatus.RECORDING
|
||||
self.stop_recording()
|
||||
self.record_button.set_to_record()
|
||||
self.current_status = self.RecordingStatus.STOPPED
|
||||
self.set_recording_status_stopped()
|
||||
|
||||
def start_recording(self):
|
||||
self.record_button.setDisabled(True)
|
||||
|
@ -567,6 +566,10 @@ class RecordingTranscriberWidget(QDialog):
|
|||
self.transcriber.finished.connect(self.transcription_thread.quit)
|
||||
self.transcriber.finished.connect(self.transcriber.deleteLater)
|
||||
|
||||
self.transcriber.error.connect(self.on_transcriber_error)
|
||||
self.transcriber.error.connect(self.transcription_thread.quit)
|
||||
self.transcriber.error.connect(self.transcriber.deleteLater)
|
||||
|
||||
self.transcription_thread.start()
|
||||
|
||||
def on_download_model_progress(self, progress: Tuple[float, float]):
|
||||
|
@ -580,11 +583,15 @@ class RecordingTranscriberWidget(QDialog):
|
|||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
|
||||
|
||||
def set_recording_status_stopped(self):
|
||||
self.record_button.set_stopped()
|
||||
self.current_status = self.RecordingStatus.STOPPED
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
show_model_download_error_dialog(self, error)
|
||||
self.stop_recording()
|
||||
self.record_button.set_to_stop()
|
||||
self.set_recording_status_stopped()
|
||||
self.record_button.setDisabled(False)
|
||||
|
||||
def on_next_transcription(self, text: str):
|
||||
|
@ -603,13 +610,18 @@ class RecordingTranscriberWidget(QDialog):
|
|||
self.record_button.setDisabled(True)
|
||||
|
||||
def on_transcriber_finished(self):
|
||||
self.record_button.setEnabled(True)
|
||||
self.reset_record_button()
|
||||
|
||||
def on_transcriber_error(self, error: str):
|
||||
self.reset_record_button()
|
||||
self.set_recording_status_stopped()
|
||||
QMessageBox.critical(self, '', f'An error occurred while starting a new recording: {error}. Please check your audio devices or check the application logs for more information.')
|
||||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.stop()
|
||||
self.reset_model_download()
|
||||
self.record_button.set_to_stop()
|
||||
self.set_recording_status_stopped()
|
||||
self.record_button.setDisabled(False)
|
||||
|
||||
def reset_model_download(self):
|
||||
|
@ -620,11 +632,14 @@ class RecordingTranscriberWidget(QDialog):
|
|||
def reset_recording_controls(self):
|
||||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
self.text_box.setPlaceholderText('')
|
||||
self.record_button.setDisabled(False)
|
||||
self.reset_record_button()
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
def reset_record_button(self):
|
||||
self.record_button.setEnabled(True)
|
||||
|
||||
def on_recording_amplitude_changed(self, amplitude: float):
|
||||
self.audio_meter_widget.update_amplitude(amplitude)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import sounddevice
|
||||
from PyQt6.QtCore import QObject, pyqtSignal
|
||||
|
@ -16,13 +17,17 @@ class RecordingAmplitudeListener(QObject):
|
|||
self.input_device_index = input_device_index
|
||||
|
||||
def start_recording(self):
|
||||
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
|
||||
channels=1, callback=self.stream_callback)
|
||||
self.stream.start()
|
||||
try:
|
||||
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
|
||||
channels=1, callback=self.stream_callback)
|
||||
self.stream.start()
|
||||
except sounddevice.PortAudioError:
|
||||
logging.exception('')
|
||||
|
||||
def stop_recording(self):
|
||||
self.stream.stop()
|
||||
self.stream.close()
|
||||
if self.stream is not None:
|
||||
self.stream.stop()
|
||||
self.stream.close()
|
||||
|
||||
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
|
||||
chunk = in_data.ravel()
|
||||
|
|
|
@ -92,6 +92,7 @@ class FileTranscriptionTask:
|
|||
class RecordingTranscriber(QObject):
|
||||
transcription = pyqtSignal(str)
|
||||
finished = pyqtSignal()
|
||||
error = pyqtSignal(str)
|
||||
is_running = False
|
||||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
|
@ -123,52 +124,57 @@ class RecordingTranscriber(QObject):
|
|||
self.transcription_options, model_path, self.sample_rate, self.input_device_index)
|
||||
|
||||
self.is_running = True
|
||||
with sounddevice.InputStream(samplerate=self.sample_rate,
|
||||
device=self.input_device_index, dtype="float32",
|
||||
channels=1, callback=self.stream_callback):
|
||||
while self.is_running:
|
||||
self.mutex.acquire()
|
||||
if self.queue.size >= self.n_batch_samples:
|
||||
samples = self.queue[:self.n_batch_samples]
|
||||
self.queue = self.queue[self.n_batch_samples:]
|
||||
self.mutex.release()
|
||||
try:
|
||||
with sounddevice.InputStream(samplerate=self.sample_rate,
|
||||
device=self.input_device_index, dtype="float32",
|
||||
channels=1, callback=self.stream_callback):
|
||||
while self.is_running:
|
||||
self.mutex.acquire()
|
||||
if self.queue.size >= self.n_batch_samples:
|
||||
samples = self.queue[:self.n_batch_samples]
|
||||
self.queue = self.queue[self.n_batch_samples:]
|
||||
self.mutex.release()
|
||||
|
||||
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
|
||||
samples.size, self.queue.size, self.amplitude(samples))
|
||||
time_started = datetime.datetime.now()
|
||||
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
|
||||
samples.size, self.queue.size, self.amplitude(samples))
|
||||
time_started = datetime.datetime.now()
|
||||
|
||||
if self.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
assert isinstance(model, whisper.Whisper)
|
||||
result = model.transcribe(
|
||||
audio=samples, language=self.transcription_options.language,
|
||||
task=self.transcription_options.task.value,
|
||||
initial_prompt=initial_prompt,
|
||||
temperature=self.transcription_options.temperature)
|
||||
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
assert isinstance(model, WhisperCpp)
|
||||
result = model.transcribe(
|
||||
audio=samples,
|
||||
params=whisper_cpp_params(
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None else 'en',
|
||||
task=self.transcription_options.task.value, word_level_timings=False))
|
||||
if self.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
assert isinstance(model, whisper.Whisper)
|
||||
result = model.transcribe(
|
||||
audio=samples, language=self.transcription_options.language,
|
||||
task=self.transcription_options.task.value,
|
||||
initial_prompt=initial_prompt,
|
||||
temperature=self.transcription_options.temperature)
|
||||
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
assert isinstance(model, WhisperCpp)
|
||||
result = model.transcribe(
|
||||
audio=samples,
|
||||
params=whisper_cpp_params(
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None else 'en',
|
||||
task=self.transcription_options.task.value, word_level_timings=False))
|
||||
else:
|
||||
assert isinstance(model, TransformersWhisper)
|
||||
result = model.transcribe(audio=samples,
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None else 'en',
|
||||
task=self.transcription_options.task.value)
|
||||
|
||||
next_text: str = result.get('text')
|
||||
|
||||
# Update initial prompt between successive recording chunks
|
||||
initial_prompt += next_text
|
||||
|
||||
logging.debug('Received next result, length = %s, time taken = %s',
|
||||
len(next_text), datetime.datetime.now() - time_started)
|
||||
self.transcription.emit(next_text)
|
||||
else:
|
||||
assert isinstance(model, TransformersWhisper)
|
||||
result = model.transcribe(audio=samples,
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None else 'en',
|
||||
task=self.transcription_options.task.value)
|
||||
|
||||
next_text: str = result.get('text')
|
||||
|
||||
# Update initial prompt between successive recording chunks
|
||||
initial_prompt += next_text
|
||||
|
||||
logging.debug('Received next result, length = %s, time taken = %s',
|
||||
len(next_text), datetime.datetime.now() - time_started)
|
||||
self.transcription.emit(next_text)
|
||||
else:
|
||||
self.mutex.release()
|
||||
self.mutex.release()
|
||||
except PortAudioError as exc:
|
||||
self.error.emit(str(exc))
|
||||
logging.exception('')
|
||||
return
|
||||
|
||||
self.finished.emit()
|
||||
|
||||
|
|
Loading…
Reference in a new issue