Mirror of https://github.com/chidiwilliams/buzz.git
Switching to pipeline for HF whisper (#814)

parent: cf340bc7d4
commit: 3d8f5da492

4 changed files with 44 additions and 79 deletions
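Before the diffs, the caller-facing change in one place: the module-level
transformers_whisper.load_model() factory is removed, and callers construct
TransformersWhisper directly from a model id; model and processor loading now
happens inside transcribe(). A minimal sketch, assuming only the shapes visible
in the diffs below (the model id and audio path are illustrative):

    from buzz.transformers_whisper import TransformersWhisper

    # Was: model = transformers_whisper.load_model("openai/whisper-tiny")
    model = TransformersWhisper("openai/whisper-tiny")
    result = model.transcribe(
        audio="speech.mp3",  # illustrative path; a numpy array also works
        language="fr",
        task="transcribe",
    )
    print(result["text"])
    for segment in result["segments"]:
        # Each segment: {"start": ..., "end": ..., "text": ..., "translation": ""}
        print(segment["start"], segment["end"], segment["text"])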
@@ -81,7 +81,7 @@ class RecordingTranscriber(QObject):
             logging.debug("Will use whisper API on %s, %s",
                           custom_openai_base_url, self.whisper_api_model)
         else:  # ModelType.HUGGING_FACE
-            model = transformers_whisper.load_model(model_path)
+            model = TransformersWhisper(model_path)
 
         initial_prompt = self.transcription_options.initial_prompt
 
@@ -12,9 +12,9 @@ from typing import Optional, List
 import tqdm
 from PyQt6.QtCore import QObject
 
-from buzz import transformers_whisper
 from buzz.conn import pipe_stderr
 from buzz.model_loader import ModelType
+from buzz.transformers_whisper import TransformersWhisper
 from buzz.transcriber.file_transcriber import FileTranscriber
 from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
 
@@ -87,7 +87,10 @@ class WhisperFileTranscriber(FileTranscriber):
     ) -> None:
         with pipe_stderr(stderr_conn):
             if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
+                # TODO Find a way to emit real progress
+                sys.stderr.write("0%\n")
                 segments = cls.transcribe_hugging_face(task)
+                sys.stderr.write("100%\n")
             elif (
                 task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
             ):
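The worker reports progress by writing percentage lines to stderr, which the
parent collects through buzz.conn.pipe_stderr. Because the pipeline path cannot
yet report granular progress, the hunk above brackets the call with fixed 0%
and 100% markers. A rough sketch of the receiving side, assuming only that
progress arrives as "NN%" lines (this parser is an illustration, not buzz's
actual pipe_stderr implementation, and worker.py is a hypothetical script):

    import re
    import subprocess
    import sys

    # Hypothetical worker that writes "0%" ... "100%" lines to stderr.
    proc = subprocess.Popen(
        [sys.executable, "worker.py"],
        stderr=subprocess.PIPE,
        text=True,
    )
    for line in proc.stderr:
        match = re.match(r"(\d+)%", line)
        if match:
            print("progress:", int(match.group(1)))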
@@ -105,7 +108,7 @@ class WhisperFileTranscriber(FileTranscriber):
 
     @classmethod
     def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
-        model = transformers_whisper.load_model(task.model_path)
+        model = TransformersWhisper(task.model_path)
         language = (
             task.transcription_options.language
             if task.transcription_options.language is not None
@@ -115,7 +118,6 @@ class WhisperFileTranscriber(FileTranscriber):
             audio=task.file_path,
             language=language,
             task=task.transcription_options.task.value,
-            verbose=False,
         )
         return [
             Segment(
@@ -1,98 +1,61 @@
 import sys
 import logging
 from typing import Optional, Union
 
 import numpy as np
-from tqdm import tqdm
 
 import whisper
 import torch
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
-
-
-def cuda_is_viable(min_vram_gb=10):
-    if not torch.cuda.is_available():
-        return False
-
-    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # Convert bytes to GB
-    if total_memory < min_vram_gb:
-        return False
-
-    return True
-
-
-def load_model(model_name_or_path: str):
-    processor = WhisperProcessor.from_pretrained(model_name_or_path)
-    model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)
-
-    if cuda_is_viable():
-        logging.debug("CUDA is available and has enough VRAM, moving model to GPU.")
-        model.to("cuda")
-
-    return TransformersWhisper(processor, model)
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 
 
 class TransformersWhisper:
     def __init__(
-        self, processor: WhisperProcessor, model: WhisperForConditionalGeneration
+        self, model_id: str
     ):
-        self.processor = processor
-        self.model = model
-        self.SAMPLE_RATE = whisper.audio.SAMPLE_RATE
-        self.N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
+        self.model_id = model_id
 
-    # Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
-    # timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
-    # https://github.com/huggingface/transformers/pull/20620.
     def transcribe(
         self,
         audio: Union[str, np.ndarray],
         language: str,
         task: str,
         verbose: Optional[bool] = None,
     ):
         if isinstance(audio, str):
             audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE)
 
-        self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
-            task=task, language=language
-        )
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            self.model_id, torch_dtype=torch_dtype, use_safetensors=True
+        )
+
+        model.generation_config.language = language
+        model.to(device)
+
+        processor = AutoProcessor.from_pretrained(self.model_id)
+
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            generate_kwargs={"language": language, "task": task},
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            chunk_length_s=30,
+            torch_dtype=torch_dtype,
+            device=device,
+        )
+
+        transcript = pipe(audio, return_timestamps=True)
 
         segments = []
-        all_predicted_ids = []
-
-        num_samples = audio.size
-        seek = 0
-        with tqdm(
-            total=num_samples, unit="samples", disable=verbose is not False
-        ) as progress_bar:
-            while seek < num_samples:
-                chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK]
-                input_features = self.processor(
-                    chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
-                ).input_features.to(self.model.device)
-                predicted_ids = self.model.generate(input_features)
-                all_predicted_ids.extend(predicted_ids)
-                text: str = self.processor.batch_decode(
-                    predicted_ids, skip_special_tokens=True
-                )[0]
-                if text.strip() != "":
-                    segments.append(
-                        {
-                            "start": seek / self.SAMPLE_RATE,
-                            "end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples)
-                            / self.SAMPLE_RATE,
-                            "text": text,
-                        }
-                    )
-
-                progress_bar.update(
-                    min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek
-                )
-                seek += self.N_SAMPLES_IN_CHUNK
+        for chunk in transcript['chunks']:
+            start, end = chunk['timestamp']
+            text = chunk['text']
+            segments.append({
+                "start": start,
+                "end": end,
+                "text": text,
+                "translation": ""
+            })
 
         return {
-            "text": self.processor.batch_decode(
-                all_predicted_ids, skip_special_tokens=True
-            )[0],
+            "text": transcript['text'],
             "segments": segments,
         }
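The net effect of this rewrite: the hand-rolled 30-second seek loop (slicing
samples, calling model.generate per chunk, decoding, and tracking progress with
tqdm) is replaced by transformers' "automatic-speech-recognition" pipeline,
which performs the long-form chunking itself when chunk_length_s is set and,
with return_timestamps=True, returns per-chunk timestamps that map directly
onto buzz's segment dicts. A minimal standalone sketch of that core call (the
model id and audio path are illustrative):

    import torch
    from transformers import pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny",  # illustrative model id
        chunk_length_s=30,            # pipeline handles the 30 s windowing
        device=device,
    )
    out = pipe("speech.mp3", return_timestamps=True)  # illustrative path

    # out["text"] is the full transcript; out["chunks"] is a list of
    # {"timestamp": (start, end), "text": ...} dicts, which the new code
    # converts into {"start", "end", "text", "translation"} segments.
    for chunk in out["chunks"]:
        start, end = chunk["timestamp"]
        print(start, end, chunk["text"])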
@@ -1,7 +1,7 @@
 import platform
 import pytest
 
-from buzz.transformers_whisper import load_model
+from buzz.transformers_whisper import TransformersWhisper
 from tests.audio import test_audio_path
 
 
@@ -11,7 +11,7 @@ from tests.audio import test_audio_path
 )
 class TestTransformersWhisper:
     def test_should_transcribe(self):
-        model = load_model("openai/whisper-tiny")
+        model = TransformersWhisper("openai/whisper-tiny")
         result = model.transcribe(
             audio=test_audio_path, language="fr", task="transcribe"
         )
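The updated test doubles as the simplest usage example: construct the wrapper
from a Hugging Face model id and call transcribe(). A sketch of what a caller
can rely on afterwards, based only on the shapes visible in the diffs above
(the audio path is illustrative):

    from buzz.transformers_whisper import TransformersWhisper

    model = TransformersWhisper("openai/whisper-tiny")
    result = model.transcribe(
        audio="speech.fr.mp3",  # illustrative path
        language="fr",
        task="transcribe",
    )
    assert isinstance(result["text"], str)
    assert all("start" in s and "end" in s and "text" in s
               for s in result["segments"])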