mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-09 11:42:12 +02:00
Adding fix for multi-byte segments in whisper.cpp (#734)
Co-authored-by: Chidi Williams <williamschidi1@gmail.com>
This commit is contained in:
parent
ca49b8e865
commit
38f5d26672
|
@ -15,6 +15,26 @@ if LOADED_WHISPER_CPP_BINARY:
|
|||
class WhisperCpp:
|
||||
def __init__(self, model: str) -> None:
    """Create a whisper.cpp context and start with no collected segments.

    Args:
        model: Model file location; it is encoded to bytes before being
            passed to ``whisper_cpp.whisper_init_from_file``.
    """
    encoded_model = model.encode()
    self.ctx = whisper_cpp.whisper_init_from_file(encoded_model)
    # Filled by ``append_segment``.
    self.segments: List[Segment] = []
|
||||
|
||||
def append_segment(self, txt: bytes, start: int, end: int):
    """Decode ``txt`` as UTF-8 and store it as a new entry in ``self.segments``.

    Args:
        txt: Raw segment text bytes.
        start: Segment start in centiseconds; stored in milliseconds.
        end: Segment end in centiseconds; stored in milliseconds.

    Returns:
        ``True`` when ``txt`` is empty (nothing is stored) or when it was
        decoded and appended; ``False`` when ``txt`` is not valid UTF-8, in
        which case nothing is stored.
    """
    if txt != b'':
        # The bytes may contain only part of a multi-byte UTF-8 character,
        # which cannot be decoded on its own; report that to the caller
        # instead of raising.
        # https://github.com/ggerganov/whisper.cpp/issues/1798
        try:
            decoded_text = txt.decode("utf-8")
            new_segment = Segment(
                start=start * 10,  # centisecond to ms
                end=end * 10,  # centisecond to ms
                text=decoded_text,
            )
            self.segments.append(new_segment)
        except UnicodeDecodeError:
            return False

    return True
|
||||
|
||||
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
|
||||
if isinstance(audio, str):
|
||||
|
@ -29,25 +49,50 @@ class WhisperCpp:
|
|||
if result != 0:
|
||||
raise Exception(f"Error from whisper.cpp: {result}")
|
||||
|
||||
segments: List[Segment] = []
|
||||
n_segments = whisper_cpp.whisper_full_n_segments(self.ctx)
|
||||
|
||||
n_segments = whisper_cpp.whisper_full_n_segments((self.ctx))
|
||||
for i in range(n_segments):
|
||||
txt = whisper_cpp.whisper_full_get_segment_text((self.ctx), i)
|
||||
t0 = whisper_cpp.whisper_full_get_segment_t0((self.ctx), i)
|
||||
t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i)
|
||||
if params.token_timestamps:
|
||||
# Will process word timestamps
|
||||
txt_buffer = b''
|
||||
txt_start = 0
|
||||
txt_end = 0
|
||||
|
||||
segments.append(
|
||||
Segment(
|
||||
start=t0 * 10, # centisecond to ms
|
||||
end=t1 * 10, # centisecond to ms
|
||||
text=txt.decode("utf-8"),
|
||||
)
|
||||
)
|
||||
for i in range(n_segments):
|
||||
txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i)
|
||||
start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i)
|
||||
end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i)
|
||||
|
||||
if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end):
|
||||
txt_buffer = txt
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
if txt.startswith(b', '):
|
||||
txt_buffer += b','
|
||||
self.append_segment(txt_buffer, txt_start, txt_end)
|
||||
txt_buffer = txt.lstrip(b',')
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
txt_buffer += txt
|
||||
txt_end = end
|
||||
|
||||
# Append the last segment
|
||||
self.append_segment(txt_buffer, txt_start, txt_end)
|
||||
|
||||
else:
|
||||
for i in range(n_segments):
|
||||
txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i)
|
||||
start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i)
|
||||
end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i)
|
||||
|
||||
self.append_segment(txt, start, end)
|
||||
|
||||
return {
|
||||
"segments": segments,
|
||||
"text": "".join([segment.text for segment in segments]),
|
||||
"segments": self.segments,
|
||||
"text": "".join([segment.text for segment in self.segments]),
|
||||
}
|
||||
|
||||
def __del__(self):
|
||||
|
|
BIN
testdata/whisper-latvian.wav
vendored
Normal file
BIN
testdata/whisper-latvian.wav
vendored
Normal file
Binary file not shown.
|
@ -3,3 +3,7 @@ import os.path
|
|||
# Directory containing this module; the fixture paths below are resolved
# relative to it.
_TESTS_DIR = os.path.dirname(__file__)

# Absolute path of the French sample recording.
test_audio_path = os.path.abspath(
    os.path.join(_TESTS_DIR, "../testdata/whisper-french.mp3")
)

# Absolute path of the Latvian sample recording.
test_multibyte_utf8_audio_path = os.path.abspath(
    os.path.join(_TESTS_DIR, "../testdata/whisper-latvian.wav")
)
|
|
@ -13,7 +13,7 @@ from buzz.transcriber.transcriber import (
|
|||
FileTranscriptionTask,
|
||||
)
|
||||
from buzz.transcriber.whisper_cpp_file_transcriber import WhisperCppFileTranscriber
|
||||
from tests.audio import test_audio_path
|
||||
from tests.audio import test_audio_path, test_multibyte_utf8_audio_path
|
||||
from tests.model_loader import get_model_path
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ class TestWhisperCppFileTranscriber:
|
|||
False,
|
||||
[Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")],
|
||||
),
|
||||
(True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]),
|
||||
(True, [Segment(30, 740, "Bienvenue"), Segment(740, 1070, " dans")]),
|
||||
],
|
||||
)
|
||||
def test_transcribe(
|
||||
|
@ -75,3 +75,63 @@ class TestWhisperCppFileTranscriber:
|
|||
assert expected_segment.start == segments[i].start
|
||||
assert expected_segment.end == segments[i].end
|
||||
assert expected_segment.text in segments[i].text
|
||||
|
||||
@pytest.mark.parametrize(
    "word_level_timings,expected_segments",
    [
        (
            False,
            [Segment(0, 7000, " Mani uzstrauts, laikabstākļi, tapēc uz jūru, es diezvajī braukša.")],
        ),
        (
            True,
            [
                Segment(380, 500, " Mani"),
                Segment(500, 1880, " uzstrauts,"),
                Segment(1880, 3920, " laikabstākļi"),
            ],
        ),
    ],
)
def test_transcribe_latvian(
    self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]
):
    """Transcribe the Latvian sample and compare the leading segments.

    The problematic part is "laikabstākļi": whisper.cpp returns "ļ" split
    across two segments, the first carrying byte b'\\xc4' and the second
    carrying byte b'\\xbc'.
    """
    file_transcription_options = FileTranscriptionOptions(
        file_paths=[test_multibyte_utf8_audio_path]
    )
    model = TranscriptionModel(
        model_type=ModelType.WHISPER_CPP,
        whisper_model_size=WhisperModelSize.TINY,
    )
    transcription_options = TranscriptionOptions(
        language="lv",
        task=Task.TRANSCRIBE,
        word_level_timings=word_level_timings,
        model=model,
    )

    task = FileTranscriptionTask(
        file_path=test_multibyte_utf8_audio_path,
        transcription_options=transcription_options,
        file_transcription_options=file_transcription_options,
        model_path=get_model_path(transcription_options.model),
    )
    transcriber = WhisperCppFileTranscriber(task=task)

    # One mock per signal so each can be inspected after the run.
    on_progress = Mock(side_effect=lambda value: print("progress: ", value))
    on_completed = Mock()
    on_error = Mock()
    transcriber.progress.connect(on_progress)
    transcriber.completed.connect(on_completed)
    transcriber.error.connect(on_error)

    # Wait up to ten minutes for the completed signal.
    with qtbot.wait_signal(transcriber.completed, timeout=10 * 60 * 1000):
        transcriber.run()

    on_error.assert_not_called()

    on_progress.assert_called()

    # First positional argument of the completed signal: the segment list.
    emitted_segments = on_completed.call_args[0][0]
    non_empty_segments = []
    for emitted in emitted_segments:
        if len(emitted.text) > 0:
            non_empty_segments.append(emitted)

    # Only the expected leading segments are compared; any further
    # segments are ignored.
    for index, expected in enumerate(expected_segments):
        assert expected.start == non_empty_segments[index].start
        assert expected.end == non_empty_segments[index].end
        assert expected.text in non_empty_segments[index].text
|
||||
|
|
Loading…
Reference in a new issue