Adding support for word level timings in Whisper API and Whisper.cpp … (#1183)

This commit is contained in:
Raivis Dejus 2025-06-18 16:33:31 +03:00 committed by GitHub
commit 4b786595c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 108 additions and 41 deletions

View file

@ -13,6 +13,7 @@ from buzz.settings.settings import Settings
from buzz.model_loader import get_custom_api_whisper_model
from buzz.transcriber.file_transcriber import FileTranscriber
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment, Task
from buzz.transcriber.whisper_cpp import append_segment
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
@ -28,6 +29,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
base_url=custom_openai_base_url if custom_openai_base_url else None
)
self.whisper_api_model = get_custom_api_whisper_model(custom_openai_base_url)
self.word_level_timings = self.transcription_task.transcription_options.word_level_timings
logging.debug("Will use whisper API on %s, %s",
custom_openai_base_url, self.whisper_api_model)
@ -136,6 +138,12 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
return segments
@staticmethod
def get_value(segment, key):
    """Return field ``key`` from a response item.

    Response items may be objects exposing their fields as attributes or
    plain dictionaries; both shapes are accepted.

    Args:
        segment: The response item (object or dict) to read from.
        key: Name of the field to fetch.

    Returns:
        The value stored under ``key``.

    Raises:
        KeyError: If ``segment`` is a mapping-like item without ``key``
            and has no attribute of that name.
    """
    # Index dictionaries first: otherwise a key that collides with a dict
    # method name ("items", "get", "values", ...) would return the bound
    # method instead of the stored value.
    if isinstance(segment, dict) and key in segment:
        return segment[key]
    if hasattr(segment, key):
        return getattr(segment, key)
    return segment[key]
def get_segments_for_file(self, file: str, offset_ms: int = 0):
with open(file, "rb") as file:
options = {
@ -144,6 +152,10 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
"response_format": "verbose_json",
"prompt": self.transcription_task.transcription_options.initial_prompt,
}
if self.word_level_timings:
options["timestamp_granularities"] = ["word"]
transcript = (
self.openai_client.audio.transcriptions.create(
**options,
@ -153,14 +165,79 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
else self.openai_client.audio.translations.create(**options)
)
return [
Segment(
int(segment["start"] * 1000 + offset_ms),
int(segment["end"] * 1000 + offset_ms),
segment["text"],
)
for segment in transcript.model_extra["segments"]
]
# Normalise the response: `segments` / `words` are read as attributes
# first, then from the undeclared extra fields kept in `model_extra`.
# `model_extra` may be None, so fall back to an empty dict.
extra_fields = getattr(transcript, "model_extra", None) or {}
segments = getattr(transcript, "segments", None)
words = getattr(transcript, "words", None)
# Compare the variable (not the string literal "words") against None so
# the extra-fields fallback can actually run.
if words is None and "words" in extra_fields:
    words = extra_fields["words"]
if segments is None:
    if "segments" in extra_fields:
        segments = extra_fields["segments"]
    else:
        # No segment list at all: wrap the words in a single pseudo-segment
        # so the segment loops that follow can treat both shapes alike.
        segments = [{"words": words}]
# Convert the normalised `segments` into Segment objects. With word-level
# timings enabled one Segment is produced per word (or per merged word
# buffer for whisper.cpp responses); otherwise one Segment per API segment.
result_segments = []
if self.word_level_timings:
# Detect response from whisper.cpp API
# The check looks only at the first segment and requires the `tokens`,
# `avg_logprob` and `no_speech_prob` attributes to be present.
first_segment = segments[0] if segments else None
is_whisper_cpp = (first_segment and hasattr(first_segment, "tokens")
and hasattr(first_segment, "avg_logprob") and hasattr(first_segment, "no_speech_prob"))
if is_whisper_cpp:
# Words are accumulated as raw UTF-8 bytes and flushed via append_segment.
# NOTE(review): append_segment multiplies start/end by 10 (centiseconds to
# ms), while the non-whisper.cpp branch below multiplies by 1000 (seconds
# to ms) - confirm the units of word timings in this branch.
# NOTE(review): offset_ms is not added to the timings in this branch - confirm.
txt_buffer = b''
txt_start = 0
txt_end = 0
for segment in segments:
for word in self.get_value(segment, "words"):
txt = self.get_value(word, "word").encode("utf-8")
start = self.get_value(word, "start")
end = self.get_value(word, "end")
# A leading space starts a new word: flush the buffer first. If the flush
# fails (append_segment returns False when the buffer is not valid UTF-8)
# the condition is False and the text is appended to the buffer below.
if txt.startswith(b' ') and append_segment(result_segments, txt_buffer, txt_start, txt_end):
txt_buffer = txt
txt_start = start
txt_end = end
continue
# A leading ", " closes the current buffer with the comma and starts a new
# buffer with the remaining text.
if txt.startswith(b', '):
txt_buffer += b','
append_segment(result_segments, txt_buffer, txt_start, txt_end)
txt_buffer = txt.lstrip(b',')
txt_start = start
txt_end = end
continue
# Anything else continues the current buffer and extends its end time.
txt_buffer += txt
txt_end = end
# Append the last segment
append_segment(result_segments, txt_buffer, txt_start, txt_end)
else:
# One Segment per word; start/end are scaled by 1000 and shifted by offset_ms.
for segment in segments:
for word in self.get_value(segment, "words"):
result_segments.append(
Segment(
int(self.get_value(word, "start") * 1000 + offset_ms),
int(self.get_value(word, "end") * 1000 + offset_ms),
self.get_value(word, "word"),
)
)
else:
# Segment-level timings: one Segment per API segment, scaled by 1000 and
# shifted by offset_ms.
result_segments = [
Segment(
int(self.get_value(segment, "start") * 1000 + offset_ms),
int(self.get_value(segment,"end") * 1000 + offset_ms),
self.get_value(segment,"text"),
)
for segment in segments
]
return result_segments
def stop(self):
pass

View file

@ -23,6 +23,24 @@ if platform.system() == "Darwin" and platform.machine() == "arm64":
except ImportError:
logging.exception("")
def append_segment(result, txt: bytes, start: int, end: int):
    """Decode ``txt`` and append it to ``result`` as a ``Segment``.

    Args:
        result: List that receives the new ``Segment``.
        txt: UTF-8 encoded text of the segment.
        start: Start time in centiseconds.
        end: End time in centiseconds.

    Returns:
        ``True`` when there was nothing to add (empty ``txt``) or the
        segment was appended; ``False`` when ``txt`` is not valid UTF-8.
    """
    if txt == b"":
        return True

    # Decoding can fail when a multi-byte UTF-8 character is split across
    # tokens: https://github.com/ggerganov/whisper.cpp/issues/1798
    try:
        decoded_text = txt.decode("utf-8")
    except UnicodeDecodeError:
        return False
    else:
        start_ms = start * 10  # centisecond to ms
        end_ms = end * 10  # centisecond to ms
        result.append(Segment(start=start_ms, end=end_ms, text=decoded_text))
        return True
class WhisperCpp:
def __init__(self, model: str) -> None:
@ -40,25 +58,6 @@ class WhisperCpp:
self.ctx = self.instance.init_from_file(model)
self.segments: List[Segment] = []
def append_segment(self, txt: bytes, start: int, end: int):
    """Decode ``txt`` and store it in ``self.segments`` as a ``Segment``.

    Returns ``True`` when ``txt`` is empty or the segment was stored, and
    ``False`` when ``txt`` cannot be decoded as UTF-8.
    """
    if txt == b"":
        return True

    # A multi-byte UTF-8 character may be split across tokens, so the
    # decode is allowed to fail:
    # https://github.com/ggerganov/whisper.cpp/issues/1798
    try:
        text = txt.decode("utf-8")
    except UnicodeDecodeError:
        return False

    # Timings arrive in centiseconds; Segment stores milliseconds.
    segment = Segment(start=start * 10, end=end * 10, text=text)
    self.segments.append(segment)
    return True
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
self.segments = []
@ -87,7 +86,7 @@ class WhisperCpp:
start = self.instance.full_get_segment_t0(self.ctx, i)
end = self.instance.full_get_segment_t1(self.ctx, i)
if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end):
if txt.startswith(b' ') and append_segment(self.segments, txt_buffer, txt_start, txt_end):
txt_buffer = txt
txt_start = start
txt_end = end
@ -95,7 +94,7 @@ class WhisperCpp:
if txt.startswith(b', '):
txt_buffer += b','
self.append_segment(txt_buffer, txt_start, txt_end)
append_segment(self.segments, txt_buffer, txt_start, txt_end)
txt_buffer = txt.lstrip(b',')
txt_start = start
txt_end = end
@ -105,7 +104,7 @@ class WhisperCpp:
txt_end = end
# Append the last segment
self.append_segment(txt_buffer, txt_start, txt_end)
append_segment(self.segments, txt_buffer, txt_start, txt_end)
else:
for i in range(n_segments):
@ -113,7 +112,7 @@ class WhisperCpp:
start = self.instance.full_get_segment_t0(self.ctx, i)
end = self.instance.full_get_segment_t1(self.ctx, i)
self.append_segment(txt, start, end)
append_segment(self.segments, txt, start, end)
return {
"segments": self.segments,

View file

@ -80,13 +80,10 @@ class FileTranscriptionFormWidget(QWidget):
layout.addLayout(file_transcription_layout)
self.setLayout(layout)
self.reset_word_level_timings()
def on_transcription_options_changed(
    self, transcription_options: TranscriptionOptions
):
    """Store the new transcription options and notify listeners.

    After storing the options, the word-level-timings checkbox state is
    refreshed and ``transcription_options_changed`` is emitted with a
    ``(transcription_options, file_transcription_options)`` tuple.
    """
    self.transcription_options = transcription_options
    self.reset_word_level_timings()
    payload = (self.transcription_options, self.file_transcription_options)
    self.transcription_options_changed.emit(payload)
@ -125,9 +122,3 @@ class FileTranscriptionFormWidget(QWidget):
)
return on_checkbox_state_changed
def reset_word_level_timings(self):
    """Disable the word-level-timings checkbox for the OpenAI Whisper API model type."""
    model_type = self.transcription_options.model.model_type
    uses_whisper_api = model_type == ModelType.OPEN_AI_WHISPER_API
    self.word_level_timings_checkbox.setDisabled(uses_whisper_api)

View file

@ -11,7 +11,7 @@ Open the Preferences window from the Menu bar, or click `Ctrl/Cmd + ,`.
**API Key** - key used to authenticate your requests to the OpenAI API. To get an API key from OpenAI, see [this article](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).
**Base URL** - By default, all requests are sent to the API provided by OpenAI. Their API URL is `https://api.openai.com/v1/`. Compatible APIs are also provided by other companies. You can find a list of available API URLs on the [discussion page](https://github.com/chidiwilliams/buzz/discussions/827).
**Base URL** - By default, all requests are sent to the API provided by OpenAI. Their API URL is `https://api.openai.com/v1/`. Compatible APIs are also provided by other companies. A list of available API URLs and of services you can run yourself is available on the [discussion page](https://github.com/chidiwilliams/buzz/discussions/827).
### Default export file name