diff --git a/buzz/transcriber/whisper_cpp.py b/buzz/transcriber/whisper_cpp.py index 32fac7ae..cce3125a 100644 --- a/buzz/transcriber/whisper_cpp.py +++ b/buzz/transcriber/whisper_cpp.py @@ -15,6 +15,26 @@ if LOADED_WHISPER_CPP_BINARY: class WhisperCpp: def __init__(self, model: str) -> None: self.ctx = whisper_cpp.whisper_init_from_file(model.encode()) + self.segments: List[Segment] = [] + + def append_segment(self, txt: bytes, start: int, end: int): + if txt == b'': + return True + + # try-catch will guard against multi-byte utf-8 characters + # https://github.com/ggerganov/whisper.cpp/issues/1798 + try: + self.segments.append( + Segment( + start=start * 10, # centisecond to ms + end=end * 10, # centisecond to ms + text=txt.decode("utf-8"), + ) + ) + + return True + except UnicodeDecodeError: + return False def transcribe(self, audio: Union[np.ndarray, str], params: Any): if isinstance(audio, str): @@ -29,25 +49,50 @@ class WhisperCpp: if result != 0: raise Exception(f"Error from whisper.cpp: {result}") - segments: List[Segment] = [] + n_segments = whisper_cpp.whisper_full_n_segments(self.ctx) - n_segments = whisper_cpp.whisper_full_n_segments((self.ctx)) - for i in range(n_segments): - txt = whisper_cpp.whisper_full_get_segment_text((self.ctx), i) - t0 = whisper_cpp.whisper_full_get_segment_t0((self.ctx), i) - t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i) + if params.token_timestamps: + # Will process word timestamps + txt_buffer = b'' + txt_start = 0 + txt_end = 0 - segments.append( - Segment( - start=t0 * 10, # centisecond to ms - end=t1 * 10, # centisecond to ms - text=txt.decode("utf-8"), - ) - ) + for i in range(n_segments): + txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i) + start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i) + end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i) + + if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end): + txt_buffer = txt + txt_start = start + txt_end = end + continue + + if txt.startswith(b', '): + txt_buffer += b',' + self.append_segment(txt_buffer, txt_start, txt_end) + txt_buffer = txt.lstrip(b',') + txt_start = start + txt_end = end + continue + + txt_buffer += txt + txt_end = end + + # Append the last segment + self.append_segment(txt_buffer, txt_start, txt_end) + + else: + for i in range(n_segments): + txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i) + start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i) + end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i) + + self.append_segment(txt, start, end) return { - "segments": segments, - "text": "".join([segment.text for segment in segments]), + "segments": self.segments, + "text": "".join([segment.text for segment in self.segments]), } def __del__(self): diff --git a/testdata/whisper-latvian.wav b/testdata/whisper-latvian.wav new file mode 100644 index 00000000..d62d5937 Binary files /dev/null and b/testdata/whisper-latvian.wav differ diff --git a/tests/audio.py b/tests/audio.py index 2c4a37c2..76a9b459 100644 --- a/tests/audio.py +++ b/tests/audio.py @@ -3,3 +3,7 @@ import os.path test_audio_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "../testdata/whisper-french.mp3") ) + +test_multibyte_utf8_audio_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../testdata/whisper-latvian.wav") +) \ No newline at end of file diff --git a/tests/transcriber/whisper_cpp_file_transcriber_test.py b/tests/transcriber/whisper_cpp_file_transcriber_test.py index 14345135..bd9a993e 100644 --- a/tests/transcriber/whisper_cpp_file_transcriber_test.py +++ b/tests/transcriber/whisper_cpp_file_transcriber_test.py @@ -13,7 +13,7 @@ from buzz.transcriber.transcriber import ( FileTranscriptionTask, ) from buzz.transcriber.whisper_cpp_file_transcriber import WhisperCppFileTranscriber -from tests.audio import test_audio_path +from tests.audio import test_audio_path, test_multibyte_utf8_audio_path from tests.model_loader import get_model_path @@ -25,7 +25,7 @@ class TestWhisperCppFileTranscriber: False, [Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")], ), - (True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]), + (True, [Segment(30, 740, "Bienvenue"), Segment(740, 1070, " dans")]), ], ) def test_transcribe( @@ -75,3 +75,63 @@ class TestWhisperCppFileTranscriber: assert expected_segment.start == segments[i].start assert expected_segment.end == segments[i].end assert expected_segment.text in segments[i].text + + @pytest.mark.parametrize( + "word_level_timings,expected_segments", + [ + ( + False, + [Segment(0, 7000, " Mani uzstrauts, laikabstākļi, tapēc uz jūru, es diezvajī braukša.")], + ), + (True, [Segment(380, 500, " Mani"), Segment(500, 1880, " uzstrauts,"), Segment(1880, 3920, " laikabstākļi")]), + ], + ) + # Problematic part is in "laikabstākļi" where "ļ" gets returned from whisper.cpp in two segments + # First segment has first byte b'\xc4' and the second has second byte b'\xbc'. + def test_transcribe_latvian( + self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment] + ): + file_transcription_options = FileTranscriptionOptions( + file_paths=[test_multibyte_utf8_audio_path] + ) + transcription_options = TranscriptionOptions( + language="lv", + task=Task.TRANSCRIBE, + word_level_timings=word_level_timings, + model=TranscriptionModel( + model_type=ModelType.WHISPER_CPP, + whisper_model_size=WhisperModelSize.TINY, + ), + ) + + model_path = get_model_path(transcription_options.model) + transcriber = WhisperCppFileTranscriber( + task=FileTranscriptionTask( + file_path=test_multibyte_utf8_audio_path, + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + model_path=model_path, + ) + ) + mock_progress = Mock(side_effect=lambda value: print("progress: ", value)) + mock_completed = Mock() + mock_error = Mock() + transcriber.progress.connect(mock_progress) + transcriber.completed.connect(mock_completed) + transcriber.error.connect(mock_error) + + with qtbot.wait_signal(transcriber.completed, timeout=10 * 60 * 1000): + transcriber.run() + + mock_error.assert_not_called() + + mock_progress.assert_called() + segments = [ + segment + for segment in mock_completed.call_args[0][0] + if len(segment.text) > 0 + ] + for i, expected_segment in enumerate(expected_segments): + assert expected_segment.start == segments[i].start + assert expected_segment.end == segments[i].end + assert expected_segment.text in segments[i].text