diff --git a/buzz/transcriber/whisper_cpp.py b/buzz/transcriber/whisper_cpp.py index 66029dcb..d77e2725 100644 --- a/buzz/transcriber/whisper_cpp.py +++ b/buzz/transcriber/whisper_cpp.py @@ -28,6 +28,10 @@ try: # On Windows whisper-server.exe subprocess will be used if (platform.system() == "Linux") and ((major > 1) or (major == 1 and minor >= 2)): from buzz.whisper_cpp_vulkan import whisper_cpp_vulkan + from buzz.whisper_cpp_vulkan.whisper_cpp_vulkan import ( + struct_whisper_context_params as struct_whisper_context_params_vulkan, + struct_whisper_aheads as struct_whisper_aheads_vulkan + ) IS_VULKAN_SUPPORTED = True LOADED_WHISPER_CPP_BINARY = True @@ -40,6 +44,10 @@ except (ImportError, Exception) as e: if not IS_VULKAN_SUPPORTED: try: from buzz.whisper_cpp import whisper_cpp # noqa: F401 + from buzz.whisper_cpp.whisper_cpp import ( + struct_whisper_context_params as struct_whisper_context_params_cpp, + struct_whisper_aheads as struct_whisper_aheads_cpp + ) LOADED_WHISPER_CPP_BINARY = True @@ -214,7 +222,23 @@ class WhisperCppCpu(WhisperCppInterface): return whisper_cpp.whisper_new_segment_callback(callback) def init_from_file(self, model: str): - return whisper_cpp.whisper_init_from_file(model.encode()) + force_cpu = os.getenv("BUZZ_FORCE_CPU", "false") + + aheads = struct_whisper_aheads_cpp() + aheads.n_heads = 0 + aheads.heads = None + params = struct_whisper_context_params_cpp( + use_gpu=force_cpu == "false", + flash_attn=False, + gpu_device=0, + dtw_token_timestamps=False, + dtw_aheads_preset=0, + dtw_n_top=0, + dtw_aheads=aheads, + dtw_mem_size=0 + ) + + return whisper_cpp.whisper_init_from_file_with_params(model.encode(), params) def full(self, ctx, params, audio, length): return whisper_cpp.whisper_full(ctx, params, audio, length) @@ -251,7 +275,23 @@ class WhisperCppVulkan(WhisperCppInterface): return whisper_cpp_vulkan.whisper_new_segment_callback(callback) def init_from_file(self, model: str): - return whisper_cpp_vulkan.whisper_init_from_file(model.encode()) + force_cpu = os.getenv("BUZZ_FORCE_CPU", "false") + + aheads = struct_whisper_aheads_vulkan() + aheads.n_heads = 0 + aheads.heads = None + params = struct_whisper_context_params_vulkan( + use_gpu=force_cpu == "false", + flash_attn=False, + gpu_device=0, + dtw_token_timestamps=False, + dtw_aheads_preset=0, + dtw_n_top=0, + dtw_aheads=aheads, + dtw_mem_size=0 + ) + + return whisper_cpp_vulkan.whisper_init_from_file_with_params(model.encode(), params) def full(self, ctx, params, audio, length): return whisper_cpp_vulkan.whisper_full(ctx, params, audio, length) diff --git a/docs/docs/preferences.md b/docs/docs/preferences.md index 0282d06a..c894a170 100644 --- a/docs/docs/preferences.md +++ b/docs/docs/preferences.md @@ -78,7 +78,7 @@ Alternatively you can set environment variables in your OS settings. See [this g ### Available variables -**BUZZ_WHISPERCPP_N_THREADS** - Number of threads to use for Whisper.cpp model. Default is `4`. +**BUZZ_WHISPERCPP_N_THREADS** - Number of threads to use for Whisper.cpp model. Default is half of available CPU cores. On a laptop with 16 threads setting `BUZZ_WHISPERCPP_N_THREADS=8` leads to some 15% speedup in transcription time. Increasing number of threads even more will lead in slower transcription time as results from parallel threads has to be