mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 22:55:46 +01:00
Adding support for word level timings in Whisper API and Whisper.cpp … (#1183)
This commit is contained in:
parent
eb58067145
commit
4b786595c3
4 changed files with 108 additions and 41 deletions
|
|
@ -13,6 +13,7 @@ from buzz.settings.settings import Settings
|
|||
from buzz.model_loader import get_custom_api_whisper_model
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment, Task
|
||||
from buzz.transcriber.whisper_cpp import append_segment
|
||||
|
||||
|
||||
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
||||
|
|
@ -28,6 +29,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
base_url=custom_openai_base_url if custom_openai_base_url else None
|
||||
)
|
||||
self.whisper_api_model = get_custom_api_whisper_model(custom_openai_base_url)
|
||||
self.word_level_timings = self.transcription_task.transcription_options.word_level_timings
|
||||
logging.debug("Will use whisper API on %s, %s",
|
||||
custom_openai_base_url, self.whisper_api_model)
|
||||
|
||||
|
|
@ -136,6 +138,12 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
|
||||
return segments
|
||||
|
||||
@staticmethod
|
||||
def get_value(segment, key):
|
||||
if hasattr(segment, key):
|
||||
return getattr(segment, key)
|
||||
return segment[key]
|
||||
|
||||
def get_segments_for_file(self, file: str, offset_ms: int = 0):
|
||||
with open(file, "rb") as file:
|
||||
options = {
|
||||
|
|
@ -144,6 +152,10 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
"response_format": "verbose_json",
|
||||
"prompt": self.transcription_task.transcription_options.initial_prompt,
|
||||
}
|
||||
|
||||
if self.word_level_timings:
|
||||
options["timestamp_granularities"] = ["word"]
|
||||
|
||||
transcript = (
|
||||
self.openai_client.audio.transcriptions.create(
|
||||
**options,
|
||||
|
|
@ -153,14 +165,79 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
else self.openai_client.audio.translations.create(**options)
|
||||
)
|
||||
|
||||
return [
|
||||
Segment(
|
||||
int(segment["start"] * 1000 + offset_ms),
|
||||
int(segment["end"] * 1000 + offset_ms),
|
||||
segment["text"],
|
||||
)
|
||||
for segment in transcript.model_extra["segments"]
|
||||
]
|
||||
segments = getattr(transcript, "segments", None)
|
||||
|
||||
# Word timings can be exposed either as a `words` attribute on the response
# or only among its untyped extra fields, so both places are checked.
words = getattr(transcript, "words", None)
# Test the variable, not the string literal: `"words" is None` is always
# False, which made this fallback unreachable.
if words is None:
    # `or {}` keeps the lookup safe if `model_extra` is missing or None.
    extra_fields = getattr(transcript, "model_extra", None) or {}
    words = extra_fields.get("words")
|
||||
|
||||
if segments is None:
|
||||
if "segments" in transcript.model_extra:
|
||||
segments = transcript.model_extra["segments"]
|
||||
else:
|
||||
segments = [{"words": words}]
|
||||
|
||||
result_segments = []
|
||||
if self.word_level_timings:
|
||||
|
||||
# Detect response from whisper.cpp API
|
||||
first_segment = segments[0] if segments else None
|
||||
is_whisper_cpp = (first_segment and hasattr(first_segment, "tokens")
|
||||
and hasattr(first_segment, "avg_logprob") and hasattr(first_segment, "no_speech_prob"))
|
||||
|
||||
if is_whisper_cpp:
|
||||
txt_buffer = b''
|
||||
txt_start = 0
|
||||
txt_end = 0
|
||||
|
||||
for segment in segments:
|
||||
for word in self.get_value(segment, "words"):
|
||||
|
||||
txt = self.get_value(word, "word").encode("utf-8")
|
||||
start = self.get_value(word, "start")
|
||||
end = self.get_value(word, "end")
|
||||
|
||||
if txt.startswith(b' ') and append_segment(result_segments, txt_buffer, txt_start, txt_end):
|
||||
txt_buffer = txt
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
if txt.startswith(b', '):
|
||||
txt_buffer += b','
|
||||
append_segment(result_segments, txt_buffer, txt_start, txt_end)
|
||||
txt_buffer = txt.lstrip(b',')
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
txt_buffer += txt
|
||||
txt_end = end
|
||||
|
||||
# Append the last segment
|
||||
append_segment(result_segments, txt_buffer, txt_start, txt_end)
|
||||
|
||||
else:
|
||||
for segment in segments:
|
||||
for word in self.get_value(segment, "words"):
|
||||
result_segments.append(
|
||||
Segment(
|
||||
int(self.get_value(word, "start") * 1000 + offset_ms),
|
||||
int(self.get_value(word, "end") * 1000 + offset_ms),
|
||||
self.get_value(word, "word"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
result_segments = [
|
||||
Segment(
|
||||
int(self.get_value(segment, "start") * 1000 + offset_ms),
|
||||
int(self.get_value(segment,"end") * 1000 + offset_ms),
|
||||
self.get_value(segment,"text"),
|
||||
)
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
return result_segments
|
||||
|
||||
def stop(self):
    """Do nothing; this transcriber defines no stop logic."""
    return None
|
||||
|
|
|
|||
|
|
@ -23,6 +23,24 @@ if platform.system() == "Darwin" and platform.machine() == "arm64":
|
|||
except ImportError:
|
||||
logging.exception("")
|
||||
|
||||
def append_segment(result, txt: bytes, start: int, end: int):
    """Decode ``txt`` as UTF-8 and append it to ``result`` as a ``Segment``.

    ``start`` and ``end`` are multiplied by 10 (centiseconds to
    milliseconds) before being stored on the segment.

    Returns ``True`` when ``txt`` is empty (nothing is appended) or when a
    segment was appended, and ``False`` when ``txt`` is not valid UTF-8, in
    which case ``result`` is left unchanged.
    """
    if txt == b'':
        return True

    start_ms = start * 10  # centisecond to ms
    end_ms = end * 10  # centisecond to ms

    # Decoding is guarded because the bytes may not form complete
    # multi-byte UTF-8 characters:
    # https://github.com/ggerganov/whisper.cpp/issues/1798
    try:
        text = txt.decode("utf-8")
    except UnicodeDecodeError:
        return False

    result.append(Segment(start=start_ms, end=end_ms, text=text))
    return True
|
||||
|
||||
class WhisperCpp:
|
||||
def __init__(self, model: str) -> None:
|
||||
|
|
@ -40,25 +58,6 @@ class WhisperCpp:
|
|||
self.ctx = self.instance.init_from_file(model)
|
||||
self.segments: List[Segment] = []
|
||||
|
||||
def append_segment(self, txt: bytes, start: int, end: int):
    """Decode ``txt`` as UTF-8 and store it in ``self.segments``.

    ``start`` and ``end`` are multiplied by 10 (centiseconds to
    milliseconds). Returns ``False`` only when ``txt`` cannot be decoded;
    empty ``txt`` is treated as success and stores nothing.
    """
    if txt != b'':
        # Building the segment is guarded because the bytes may not form
        # complete multi-byte UTF-8 characters:
        # https://github.com/ggerganov/whisper.cpp/issues/1798
        try:
            segment = Segment(
                start=start * 10,  # centisecond to ms
                end=end * 10,  # centisecond to ms
                text=txt.decode("utf-8"),
            )
        except UnicodeDecodeError:
            return False
        self.segments.append(segment)

    return True
|
||||
|
||||
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
|
||||
self.segments = []
|
||||
|
||||
|
|
@ -87,7 +86,7 @@ class WhisperCpp:
|
|||
start = self.instance.full_get_segment_t0(self.ctx, i)
|
||||
end = self.instance.full_get_segment_t1(self.ctx, i)
|
||||
|
||||
if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end):
|
||||
if txt.startswith(b' ') and append_segment(self.segments, txt_buffer, txt_start, txt_end):
|
||||
txt_buffer = txt
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
|
|
@ -95,7 +94,7 @@ class WhisperCpp:
|
|||
|
||||
if txt.startswith(b', '):
|
||||
txt_buffer += b','
|
||||
self.append_segment(txt_buffer, txt_start, txt_end)
|
||||
append_segment(self.segments, txt_buffer, txt_start, txt_end)
|
||||
txt_buffer = txt.lstrip(b',')
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
|
|
@ -105,7 +104,7 @@ class WhisperCpp:
|
|||
txt_end = end
|
||||
|
||||
# Append the last segment
|
||||
self.append_segment(txt_buffer, txt_start, txt_end)
|
||||
append_segment(self.segments, txt_buffer, txt_start, txt_end)
|
||||
|
||||
else:
|
||||
for i in range(n_segments):
|
||||
|
|
@ -113,7 +112,7 @@ class WhisperCpp:
|
|||
start = self.instance.full_get_segment_t0(self.ctx, i)
|
||||
end = self.instance.full_get_segment_t1(self.ctx, i)
|
||||
|
||||
self.append_segment(txt, start, end)
|
||||
append_segment(self.segments, txt, start, end)
|
||||
|
||||
return {
|
||||
"segments": self.segments,
|
||||
|
|
|
|||
|
|
@ -80,13 +80,10 @@ class FileTranscriptionFormWidget(QWidget):
|
|||
layout.addLayout(file_transcription_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
self.reset_word_level_timings()
|
||||
|
||||
def on_transcription_options_changed(
    self, transcription_options: TranscriptionOptions
):
    """Store the new transcription options and announce the change.

    After saving ``transcription_options`` on the widget, the word-level
    timings checkbox state is refreshed and ``transcription_options_changed``
    is emitted with a ``(transcription_options, file_transcription_options)``
    tuple.
    """
    self.transcription_options = transcription_options
    self.reset_word_level_timings()

    current_options = (
        self.transcription_options,
        self.file_transcription_options,
    )
    self.transcription_options_changed.emit(current_options)
|
||||
|
|
@ -125,9 +122,3 @@ class FileTranscriptionFormWidget(QWidget):
|
|||
)
|
||||
|
||||
return on_checkbox_state_changed
|
||||
|
||||
def reset_word_level_timings(self):
    """Disable the word-level timings checkbox for the OpenAI Whisper API.

    The checkbox is disabled when the selected model type is
    ``ModelType.OPEN_AI_WHISPER_API`` and enabled for every other type.
    """
    selected_type = self.transcription_options.model.model_type
    uses_openai_api = selected_type == ModelType.OPEN_AI_WHISPER_API
    self.word_level_timings_checkbox.setDisabled(uses_openai_api)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ Open the Preferences window from the Menu bar, or click `Ctrl/Cmd + ,`.
|
|||
|
||||
**API Key** - the key used to authenticate your requests to the OpenAI API. To get an API key from OpenAI, see [this article](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).
|
||||
|
||||
**Base URL** - By default all requests are sent to API provided by OpenAI company. Their API URL is `https://api.openai.com/v1/`. Compatible APIs are also provided by other companies. List of available API URLs you can find on [discussion page](https://github.com/chidiwilliams/buzz/discussions/827)
|
||||
**Base URL** - By default, all requests are sent to the API provided by OpenAI. Their API URL is `https://api.openai.com/v1/`. Compatible APIs are also provided by other companies. A list of available API URLs, and of services you can run yourself, is available on the [discussion page](https://github.com/chidiwilliams/buzz/discussions/827).
|
||||
|
||||
### Default export file name
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue