Adding fix for multi-byte segments in whisper.cpp (#734)

Co-authored-by: Chidi Williams <williamschidi1@gmail.com>
This commit is contained in:
Raivis Dejus 2024-05-15 02:29:43 +03:00 committed by GitHub
parent ca49b8e865
commit 38f5d26672
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 126 additions and 17 deletions

View file

@ -15,6 +15,26 @@ if LOADED_WHISPER_CPP_BINARY:
class WhisperCpp:
def __init__(self, model: str) -> None:
self.ctx = whisper_cpp.whisper_init_from_file(model.encode())
self.segments: List[Segment] = []
def append_segment(self, txt: bytes, start: int, end: int):
if txt == b'':
return True
# try-catch will guard against multi-byte utf-8 characters
# https://github.com/ggerganov/whisper.cpp/issues/1798
try:
self.segments.append(
Segment(
start=start * 10, # centisecond to ms
end=end * 10, # centisecond to ms
text=txt.decode("utf-8"),
)
)
return True
except UnicodeDecodeError:
return False
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
if isinstance(audio, str):
@ -29,25 +49,50 @@ class WhisperCpp:
if result != 0:
raise Exception(f"Error from whisper.cpp: {result}")
segments: List[Segment] = []
n_segments = whisper_cpp.whisper_full_n_segments(self.ctx)
n_segments = whisper_cpp.whisper_full_n_segments((self.ctx))
for i in range(n_segments):
txt = whisper_cpp.whisper_full_get_segment_text((self.ctx), i)
t0 = whisper_cpp.whisper_full_get_segment_t0((self.ctx), i)
t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i)
if params.token_timestamps:
# Will process word timestamps
txt_buffer = b''
txt_start = 0
txt_end = 0
segments.append(
Segment(
start=t0 * 10, # centisecond to ms
end=t1 * 10, # centisecond to ms
text=txt.decode("utf-8"),
)
)
for i in range(n_segments):
txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i)
start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i)
end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i)
if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end):
txt_buffer = txt
txt_start = start
txt_end = end
continue
if txt.startswith(b', '):
txt_buffer += b','
self.append_segment(txt_buffer, txt_start, txt_end)
txt_buffer = txt.lstrip(b',')
txt_start = start
txt_end = end
continue
txt_buffer += txt
txt_end = end
# Append the last segment
self.append_segment(txt_buffer, txt_start, txt_end)
else:
for i in range(n_segments):
txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i)
start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i)
end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i)
self.append_segment(txt, start, end)
return {
"segments": segments,
"text": "".join([segment.text for segment in segments]),
"segments": self.segments,
"text": "".join([segment.text for segment in self.segments]),
}
def __del__(self):

BIN
testdata/whisper-latvian.wav vendored Normal file

Binary file not shown.

View file

@ -3,3 +3,7 @@ import os.path
test_audio_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../testdata/whisper-french.mp3")
)
test_multibyte_utf8_audio_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../testdata/whisper-latvian.wav")
)

View file

@ -13,7 +13,7 @@ from buzz.transcriber.transcriber import (
FileTranscriptionTask,
)
from buzz.transcriber.whisper_cpp_file_transcriber import WhisperCppFileTranscriber
from tests.audio import test_audio_path
from tests.audio import test_audio_path, test_multibyte_utf8_audio_path
from tests.model_loader import get_model_path
@ -25,7 +25,7 @@ class TestWhisperCppFileTranscriber:
False,
[Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")],
),
(True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]),
(True, [Segment(30, 740, "Bienvenue"), Segment(740, 1070, " dans")]),
],
)
def test_transcribe(
@ -75,3 +75,63 @@ class TestWhisperCppFileTranscriber:
assert expected_segment.start == segments[i].start
assert expected_segment.end == segments[i].end
assert expected_segment.text in segments[i].text
@pytest.mark.parametrize(
"word_level_timings,expected_segments",
[
(
False,
[Segment(0, 7000, " Mani uzstrauts, laikabstākļi, tapēc uz jūru, es diezvajī braukša.")],
),
(True, [Segment(380, 500, " Mani"), Segment(500, 1880, " uzstrauts,"), Segment(1880, 3920, " laikabstākļi")]),
],
)
# Problematic part is in "laikabstākļi" where "ļ" gets returned from whisper.cpp in two segments
# First segment has first byte b'\xc4' and the second has second byte b'\xbc'.
def test_transcribe_latvian(
self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]
):
file_transcription_options = FileTranscriptionOptions(
file_paths=[test_multibyte_utf8_audio_path]
)
transcription_options = TranscriptionOptions(
language="lv",
task=Task.TRANSCRIBE,
word_level_timings=word_level_timings,
model=TranscriptionModel(
model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY,
),
)
model_path = get_model_path(transcription_options.model)
transcriber = WhisperCppFileTranscriber(
task=FileTranscriptionTask(
file_path=test_multibyte_utf8_audio_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
model_path=model_path,
)
)
mock_progress = Mock(side_effect=lambda value: print("progress: ", value))
mock_completed = Mock()
mock_error = Mock()
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
transcriber.error.connect(mock_error)
with qtbot.wait_signal(transcriber.completed, timeout=10 * 60 * 1000):
transcriber.run()
mock_error.assert_not_called()
mock_progress.assert_called()
segments = [
segment
for segment in mock_completed.call_args[0][0]
if len(segment.text) > 0
]
for i, expected_segment in enumerate(expected_segments):
assert expected_segment.start == segments[i].start
assert expected_segment.end == segments[i].end
assert expected_segment.text in segments[i].text