buzz/buzz/transcriber/whisper_cpp.py
Raivis Dejus e22ded1c2b Build adjustments
# Conflicts:
#	buzz/transcriber/whisper_cpp.py
2025-06-18 18:20:41 +03:00

310 lines
No EOL
10 KiB
Python

import platform
import os
import ctypes
import logging
from typing import Union, Any, List
import numpy as np
from buzz import whisper_audio
from buzz.model_loader import LOADED_WHISPER_CPP_BINARY
from buzz.transcriber.transcriber import Segment, Task, TranscriptionOptions
IS_COREML_SUPPORTED = False
if platform.system() == "Darwin" and platform.machine() == "arm64":
    # The Core ML bindings are only attempted on Apple Silicon macOS; a
    # missing/unbuildable bindings module simply leaves the flag False.
    try:
        from buzz.whisper_cpp_coreml import whisper_cpp_coreml  # noqa: F401
    except ImportError:
        IS_COREML_SUPPORTED = False
    else:
        IS_COREML_SUPPORTED = True
IS_VULKAN_SUPPORTED = False
try:
    import vulkan

    # Probe the driver by creating and immediately destroying a bare Vulkan
    # instance; presumably this raises when no usable Vulkan driver exists.
    instance = vulkan.vkCreateInstance(vulkan.VkInstanceCreateInfo(), None)
    vulkan.vkDestroyInstance(instance, None)

    from buzz.whisper_cpp_vulkan import whisper_cpp_vulkan
except Exception:
    # Any failure (missing `vulkan` package, probe failure, missing
    # bindings) disables the Vulkan backend.
    IS_VULKAN_SUPPORTED = False
else:
    IS_VULKAN_SUPPORTED = True
# Load the CPU bindings only when no accelerated backend was detected and the
# binary is available; when skipped, `whisper_cpp_cpu` stays unbound here.
if not (IS_VULKAN_SUPPORTED or IS_COREML_SUPPORTED) and LOADED_WHISPER_CPP_BINARY:
    from buzz.whisper_cpp import whisper_cpp as whisper_cpp_cpu
def append_segment(result, txt: bytes, start: int, end: int):
    """Decode one whisper.cpp segment and append it to *result*.

    Args:
        result: List that receives the new ``Segment``.
        txt: Raw UTF-8 bytes of the segment text.
        start: Segment start in centiseconds.
        end: Segment end in centiseconds.

    Returns:
        True if *txt* was empty (nothing to do) or was appended; False if
        *txt* is not valid UTF-8, e.g. a multi-byte character split across
        tokens (https://github.com/ggerganov/whisper.cpp/issues/1798), in
        which case *result* is left untouched.
    """
    if txt == b'':
        return True

    start_ms = start * 10  # centisecond to ms
    end_ms = end * 10  # centisecond to ms

    try:
        text = txt.decode("utf-8")
    except UnicodeDecodeError:
        return False

    result.append(Segment(start=start_ms, end=end_ms, text=text))
    return True
class WhisperCpp:
    """Transcribe audio with whisper.cpp through one of the ctypes backends.

    The backend (Core ML, Vulkan or CPU) is picked once per instance from the
    module-level capability flags; Core ML is additionally dropped when the
    model's compiled encoder is not present next to the model file.
    """

    def __init__(self, model: str) -> None:
        """Load the whisper.cpp model at *model* with the chosen backend."""
        self.is_coreml_supported = IS_COREML_SUPPORTED
        self.is_vulkan_supported = IS_VULKAN_SUPPORTED

        # Core ML also needs the compiled encoder bundle; fall back otherwise.
        if self.is_coreml_supported and not os.path.exists(
            self._coreml_encoder_path(model)
        ):
            self.is_coreml_supported = False

        logging.debug(f"WhisperCpp model {model}, "
                      f"(Core ML: {self.is_coreml_supported}, Vulkan: {self.is_vulkan_supported})")

        self.instance = self.get_instance()
        self.ctx = self.instance.init_from_file(model)
        self.segments: List[Segment] = []

    @staticmethod
    def _coreml_encoder_path(model: str) -> str:
        """Return the Core ML encoder path expected for *model*.

        Only the final extension is replaced ("ggml-base.bin" ->
        "ggml-base-encoder.mlmodelc"). Directories containing ".bin" are left
        alone, and a model without a ".bin" suffix no longer maps to the model
        file itself (which always "exists").
        """
        stem, _ext = os.path.splitext(model)
        return stem + "-encoder.mlmodelc"

    @staticmethod
    def _n_threads() -> int:
        """Return BUZZ_WHISPERCPP_N_THREADS, else half the CPUs (at least 1)."""
        # Clamp the computed default: on a 1-CPU machine 1 // 2 would be 0.
        default = max(1, (os.cpu_count() or 8) // 2)
        return int(os.getenv("BUZZ_WHISPERCPP_N_THREADS", default))

    def _segment_at(self, i):
        """Return (text bytes, t0, t1) of segment *i* from the last full() run."""
        return (
            self.instance.full_get_segment_text(self.ctx, i),
            self.instance.full_get_segment_t0(self.ctx, i),
            self.instance.full_get_segment_t1(self.ctx, i),
        )

    def _collect_words(self, n_segments: int) -> None:
        """Merge the short token segments into word segments in self.segments.

        A segment starting with b' ' starts a new word, and one starting with
        b', ' closes the current word with a comma. Anything else extends the
        current word. If flushing the buffer fails because it is not valid
        UTF-8 yet (a multi-byte character split across tokens), the buffer is
        kept and extended instead.
        NOTE(review): in the b', ' branch a failed flush drops the buffer.
        """
        txt_buffer = b''
        txt_start = 0
        txt_end = 0
        for i in range(n_segments):
            txt, start, end = self._segment_at(i)
            if txt.startswith(b' ') and append_segment(self.segments, txt_buffer, txt_start, txt_end):
                txt_buffer = txt
                txt_start = start
                txt_end = end
                continue
            if txt.startswith(b', '):
                txt_buffer += b','
                append_segment(self.segments, txt_buffer, txt_start, txt_end)
                txt_buffer = txt.lstrip(b',')
                txt_start = start
                txt_end = end
                continue
            txt_buffer += txt
            txt_end = end
        # Append the last segment
        append_segment(self.segments, txt_buffer, txt_start, txt_end)

    def transcribe(self, audio: Union[np.ndarray, str], params: Any):
        """Transcribe *audio* and return its segments and concatenated text.

        Args:
            audio: Samples as an array, or a path passed to
                ``whisper_audio.load_audio``.
            params: Params object, typically from ``get_params``.

        Returns:
            ``{"segments": [Segment, ...], "text": str}``.

        Raises:
            Exception: if whisper.cpp's ``full`` call returns non-zero.
        """
        self.segments = []

        if isinstance(audio, str):
            audio = whisper_audio.load_audio(audio)

        logging.debug("Loaded audio with length = %s", len(audio))

        # The buffer is handed to C as a float*, so it must be contiguous
        # float32 (no copy when it already is). Keeping it bound to `audio`
        # keeps the memory alive for the duration of the call.
        audio = np.ascontiguousarray(audio, dtype=np.float32)
        whisper_cpp_audio = audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

        result = self.instance.full(
            self.ctx, params, whisper_cpp_audio, len(audio)
        )
        if result != 0:
            raise Exception(f"Error from whisper.cpp: {result}")

        n_segments = self.instance.full_n_segments(self.ctx)

        if params.token_timestamps:
            # Will process word timestamps
            self._collect_words(n_segments)
        else:
            for i in range(n_segments):
                txt, start, end = self._segment_at(i)
                append_segment(self.segments, txt, start, end)

        return {
            "segments": self.segments,
            "text": "".join(segment.text for segment in self.segments),
        }

    def get_instance(self):
        """Return the backend wrapper: Core ML, then Vulkan, then CPU."""
        if self.is_coreml_supported:
            return WhisperCppCoreML()
        if self.is_vulkan_supported:
            return WhisperCppVulkan()
        return WhisperCppCpu()

    def get_params(
        self,
        transcription_options: "TranscriptionOptions",
        print_realtime=False,
        print_progress=False,
    ):
        """Build whisper.cpp full() params from *transcription_options*."""
        params = self.instance.full_default_params(0)  # WHISPER_SAMPLING_GREEDY
        params.n_threads = self._n_threads()
        params.print_realtime = print_realtime
        params.print_progress = print_progress
        params.language = self.instance.get_string((transcription_options.language or "en"))
        params.translate = transcription_options.task == Task.TRANSLATE
        # Word-level timings: max_len=1 together with token timestamps asks
        # for the shortest segments; transcribe() re-merges them into words.
        params.max_len = 1 if transcription_options.word_level_timings else 0
        params.token_timestamps = transcription_options.word_level_timings
        # A missing prompt is passed as an empty string instead of crashing
        # on None.encode().
        params.initial_prompt = self.instance.get_string(transcription_options.initial_prompt or "")
        return params

    def __del__(self):
        """Free the whisper.cpp context once, tolerating a failed __init__."""
        instance = getattr(self, "instance", None)
        ctx = getattr(self, "ctx", None)
        if instance and ctx:
            instance.free(ctx)
            # Guard against a second free if __del__ is invoked explicitly.
            self.ctx = None
class WhisperCppInterface:
    """Backend-neutral facade over one build of the whisper.cpp bindings.

    Subclasses forward each method to their bindings module; this base class
    only raises NotImplementedError.
    """

    # --- parameters and strings -------------------------------------------

    def full_default_params(self, sampling: int):
        """Return default full() params for the given sampling strategy."""
        raise NotImplementedError

    def get_string(self, string: str):
        """Wrap *string* in the bindings' string type."""
        raise NotImplementedError

    # --- callbacks ---------------------------------------------------------

    def get_encoder_begin_callback(self, callback):
        """Wrap *callback* as an encoder-begin callback."""
        raise NotImplementedError

    def get_new_segment_callback(self, callback):
        """Wrap *callback* as a new-segment callback."""
        raise NotImplementedError

    # --- context lifecycle and inference -----------------------------------

    def init_from_file(self, model: str):
        """Load the model at path *model* and return a context."""
        raise NotImplementedError

    def full(self, ctx, params, audio, length):
        """Run full transcription of *length* samples from *audio*."""
        raise NotImplementedError

    def full_n_segments(self, ctx):
        """Return the number of segments produced by the last full()."""
        raise NotImplementedError

    def full_get_segment_text(self, ctx, i):
        """Return the text of segment *i*."""
        raise NotImplementedError

    def full_get_segment_t0(self, ctx, i):
        """Return the start time of segment *i*."""
        raise NotImplementedError

    def full_get_segment_t1(self, ctx, i):
        """Return the end time of segment *i*."""
        raise NotImplementedError

    def free(self, ctx):
        """Release *ctx*."""
        raise NotImplementedError
class WhisperCppCpu(WhisperCppInterface):
    """WhisperCppInterface backed by the CPU build of whisper.cpp."""

    @staticmethod
    def _lib():
        """Return the CPU bindings module, importing it on first use.

        The module-level import of `whisper_cpp_cpu` is skipped whenever Core
        ML or Vulkan support was detected. However, WhisperCpp still selects
        this backend when the model's Core ML encoder is missing and Vulkan is
        unavailable, which used to fail with NameError. Import lazily in that
        case.
        """
        global whisper_cpp_cpu
        lib = globals().get("whisper_cpp_cpu")
        if lib is None:
            from buzz.whisper_cpp import whisper_cpp as lib
            whisper_cpp_cpu = lib
        return lib

    def full_default_params(self, sampling: int):
        return self._lib().whisper_full_default_params(sampling)

    def get_string(self, string: str):
        return self._lib().String(string.encode())

    def get_encoder_begin_callback(self, callback):
        return self._lib().whisper_encoder_begin_callback(callback)

    def get_new_segment_callback(self, callback):
        return self._lib().whisper_new_segment_callback(callback)

    def init_from_file(self, model: str):
        return self._lib().whisper_init_from_file(model.encode())

    def full(self, ctx, params, audio, length):
        return self._lib().whisper_full(ctx, params, audio, length)

    def full_n_segments(self, ctx):
        return self._lib().whisper_full_n_segments(ctx)

    def full_get_segment_text(self, ctx, i):
        return self._lib().whisper_full_get_segment_text(ctx, i)

    def full_get_segment_t0(self, ctx, i):
        return self._lib().whisper_full_get_segment_t0(ctx, i)

    def full_get_segment_t1(self, ctx, i):
        return self._lib().whisper_full_get_segment_t1(ctx, i)

    def free(self, ctx):
        """Free *ctx*; a no-op when ctx is falsy or the bindings are gone.

        Never imports: an unbound or None module global (e.g. during
        interpreter shutdown) simply skips the free instead of raising.
        """
        lib = globals().get("whisper_cpp_cpu")
        if ctx and lib is not None:
            return lib.whisper_free(ctx)
        return None
class WhisperCppCoreML(WhisperCppInterface):
    """WhisperCppInterface backed by the Core ML build of whisper.cpp.

    Every call forwards to the `whisper_cpp_coreml` bindings module, which is
    bound at module level only when its import succeeded.
    """

    @staticmethod
    def _lib():
        # Resolved on every call, exactly like a direct global reference.
        return whisper_cpp_coreml

    def full_default_params(self, sampling: int):
        return self._lib().whisper_full_default_params(sampling)

    def get_string(self, string: str):
        return self._lib().String(string.encode())

    def get_encoder_begin_callback(self, callback):
        return self._lib().whisper_encoder_begin_callback(callback)

    def get_new_segment_callback(self, callback):
        return self._lib().whisper_new_segment_callback(callback)

    def init_from_file(self, model: str):
        return self._lib().whisper_init_from_file(model.encode())

    def full(self, ctx, params, audio, length):
        return self._lib().whisper_full(ctx, params, audio, length)

    def full_n_segments(self, ctx):
        return self._lib().whisper_full_n_segments(ctx)

    def full_get_segment_text(self, ctx, i):
        return self._lib().whisper_full_get_segment_text(ctx, i)

    def full_get_segment_t0(self, ctx, i):
        return self._lib().whisper_full_get_segment_t0(ctx, i)

    def full_get_segment_t1(self, ctx, i):
        return self._lib().whisper_full_get_segment_t1(ctx, i)

    def free(self, ctx):
        """Free *ctx* unless it is falsy or the bindings module is None."""
        if not ctx:
            return None
        lib = whisper_cpp_coreml
        if lib is None:
            return None
        return lib.whisper_free(ctx)
class WhisperCppVulkan(WhisperCppInterface):
    """WhisperCppInterface backed by the Vulkan build of whisper.cpp.

    Every call forwards to the `whisper_cpp_vulkan` bindings module, which is
    bound at module level only when the Vulkan probe and import succeeded.
    """

    @staticmethod
    def _lib():
        # Resolved on every call, exactly like a direct global reference.
        return whisper_cpp_vulkan

    def full_default_params(self, sampling: int):
        return self._lib().whisper_full_default_params(sampling)

    def get_string(self, string: str):
        return self._lib().String(string.encode())

    def get_encoder_begin_callback(self, callback):
        return self._lib().whisper_encoder_begin_callback(callback)

    def get_new_segment_callback(self, callback):
        return self._lib().whisper_new_segment_callback(callback)

    def init_from_file(self, model: str):
        return self._lib().whisper_init_from_file(model.encode())

    def full(self, ctx, params, audio, length):
        return self._lib().whisper_full(ctx, params, audio, length)

    def full_n_segments(self, ctx):
        return self._lib().whisper_full_n_segments(ctx)

    def full_get_segment_text(self, ctx, i):
        return self._lib().whisper_full_get_segment_text(ctx, i)

    def full_get_segment_t0(self, ctx, i):
        return self._lib().whisper_full_get_segment_t0(ctx, i)

    def full_get_segment_t1(self, ctx, i):
        return self._lib().whisper_full_get_segment_t1(ctx, i)

    def free(self, ctx):
        """Free *ctx* unless it is falsy or the bindings module is None."""
        if not ctx:
            return None
        lib = whisper_cpp_vulkan
        if lib is None:
            return None
        return lib.whisper_free(ctx)