diff --git a/buzz/transcriber/recording_transcriber.py b/buzz/transcriber/recording_transcriber.py
index af0c2d71..1635bd54 100644
--- a/buzz/transcriber/recording_transcriber.py
+++ b/buzz/transcriber/recording_transcriber.py
@@ -7,6 +7,7 @@ import tempfile
 import threading
 from typing import Optional
 
+import torch
 import numpy as np
 import sounddevice
 from sounddevice import PortAudioError
@@ -60,7 +61,9 @@ class RecordingTranscriber(QObject):
         keep_samples = int(0.15 * self.sample_rate)
 
         if self.transcription_options.model.model_type == ModelType.WHISPER:
-            model = whisper.load_model(model_path)
+            # Run on the GPU when CUDA is available; fall back to CPU otherwise.
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            model = whisper.load_model(model_path, device=device)
         elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
             model = WhisperCpp(model_path)
         elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
diff --git a/buzz/transcriber/whisper_file_transcriber.py b/buzz/transcriber/whisper_file_transcriber.py
index bd024fcf..abb34d86 100644
--- a/buzz/transcriber/whisper_file_transcriber.py
+++ b/buzz/transcriber/whisper_file_transcriber.py
@@ -4,6 +4,7 @@ import logging
 import multiprocessing
 import re
 import sys
 from multiprocessing.connection import Connection
 from threading import Thread
 from typing import Optional, List
+import torch
@@ -168,7 +169,9 @@
 
     @classmethod
     def transcribe_openai_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
-        model = whisper.load_model(task.model_path)
+        # Run on the GPU when CUDA is available; fall back to CPU otherwise.
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = whisper.load_model(task.model_path, device=device)
 
         if task.transcription_options.word_level_timings:
             stable_whisper.modify_model(model)