Add segments debug logging (#294)

This commit is contained in:
Chidi Williams 2023-01-03 20:21:59 +00:00 committed by GitHub
parent 40b0236f8c
commit 614b395962
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 19 deletions

View file

@ -126,8 +126,8 @@ class RecordingTranscriber(QObject):
self.is_running = True
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
@ -136,7 +136,7 @@ class RecordingTranscriber(QObject):
self.mutex.release()
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()
if self.transcription_options.model.model_type == ModelType.WHISPER:
@ -157,9 +157,9 @@ class RecordingTranscriber(QObject):
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
next_text: str = result.get('text')
@ -167,7 +167,7 @@ class RecordingTranscriber(QObject):
initial_prompt += next_text
logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
self.mutex.release()
@ -367,19 +367,16 @@ class WhisperFileTranscriber(QObject):
self.current_process.join()
logging.debug(
'whisper process completed with code = %s, time taken = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started)
if self.current_process.exitcode != 0:
send_pipe.close()
self.read_line_thread.join()
# TODO: fix error handling when process crashes
if self.current_process.exitcode != 0 and self.current_process.exitcode is not None:
self.completed.emit([])
logging.debug(
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments))
self.completed.emit(self.segments)
self.running = False
def stop(self):
@ -403,9 +400,7 @@ class WhisperFileTranscriber(QObject):
end=segment.get('end'),
text=segment.get('text'),
) for segment in segments_dict]
self.current_process.join()
# TODO: move this back to the parent thread
self.completed.emit(segments)
self.segments = segments
else:
try:
progress = int(line.split('|')[0].strip().strip('%'))

View file

@ -1,3 +1,4 @@
import logging
import os
import pathlib
import platform
@ -7,7 +8,7 @@ from typing import List
from unittest.mock import Mock, patch
import pytest
from PyQt6.QtCore import QThread
from PyQt6.QtCore import QThread, QCoreApplication
from pytestqt.qtbot import QtBot
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel, ModelLoader
@ -135,7 +136,8 @@ class TestWhisperFileTranscriber:
file_path='testdata/whisper-french.mp3', model_path=model_path))
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(transcriber.completed,
timeout=10 * 6000):
transcriber.run()
if check_progress:
@ -150,6 +152,7 @@ class TestWhisperFileTranscriber:
mock_completed.assert_called()
segments = mock_completed.call_args[0][0]
assert len(segments) >= len(expected_segments)
for (i, expected_segment) in enumerate(expected_segments):
assert segments[i] == expected_segment