Adding support for Whisper API in live recordings (#807)

This commit is contained in:
Raivis Dejus 2024-06-21 08:26:49 +03:00 committed by GitHub
commit 4726d58af6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 104 additions and 37 deletions

View file

@ -8,8 +8,8 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-06-19 21:34+0300\n"
"PO-Revision-Date: 2024-06-19 21:34+0300\n"
"POT-Creation-Date: 2024-06-20 23:14+0300\n"
"PO-Revision-Date: 2024-06-20 23:15+0300\n"
"Last-Translator: \n"
"Language-Team: \n"
"Language: lv_LV\n"
@ -191,7 +191,7 @@ msgid "Stop"
msgstr "Apturēt"
#: buzz/widgets/transcriber/languages_combo_box.py:25
#: buzz/transcriber/transcriber.py:153
#: buzz/transcriber/transcriber.py:159
msgid "Detect Language"
msgstr "Noteikt valodu"
@ -204,8 +204,8 @@ msgid "Model:"
msgstr "Modelis:"
#: buzz/widgets/transcriber/transcription_options_group_box.py:89
msgid "Access Token:"
msgstr "Pieejas atslēga:"
msgid "Api Key:"
msgstr "API atslēga:"
#: buzz/widgets/transcriber/transcription_options_group_box.py:90
msgid "Task:"
@ -310,60 +310,51 @@ msgstr "Atcelts"
msgid "Queued"
msgstr "Ierindots"
#: buzz/widgets/transcription_tasks_table_widget.py:83
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:154
msgid "Translate"
msgstr "Tulkot"
#: buzz/widgets/transcription_tasks_table_widget.py:84
msgid "Tanscribe"
msgstr "Atpazīt"
#: buzz/widgets/transcription_tasks_table_widget.py:90
msgid "File Name / URL"
msgstr "Fails / URL"
#: buzz/widgets/transcription_tasks_table_widget.py:102
#: buzz/widgets/transcription_tasks_table_widget.py:96
msgid "Model"
msgstr "Modelis"
#: buzz/widgets/transcription_tasks_table_widget.py:111
#: buzz/widgets/transcription_tasks_table_widget.py:105
msgid "Task"
msgstr "Uzdevums"
#: buzz/widgets/transcription_tasks_table_widget.py:120
#: buzz/widgets/transcription_tasks_table_widget.py:114
msgid "Status"
msgstr "Statuss"
#: buzz/widgets/transcription_tasks_table_widget.py:128
#: buzz/widgets/transcription_tasks_table_widget.py:122
msgid "Date Added"
msgstr "Pievienots"
#: buzz/widgets/transcription_tasks_table_widget.py:139
#: buzz/widgets/transcription_tasks_table_widget.py:133
msgid "Date Completed"
msgstr "Pabeigts"
#: buzz/widgets/recording_transcriber_widget.py:71
#: buzz/widgets/recording_transcriber_widget.py:72
msgid "Live Recording"
msgstr "Dzīvā ierakstīšana"
#: buzz/widgets/recording_transcriber_widget.py:130
#: buzz/widgets/recording_transcriber_widget.py:136
msgid "Click Record to begin..."
msgstr "Klikšķiniet Ierakstīt, lai sāktu..."
#: buzz/widgets/recording_transcriber_widget.py:133
#: buzz/widgets/recording_transcriber_widget.py:139
msgid "Waiting for AI translation..."
msgstr "Gaida MI tulkojumu..."
#: buzz/widgets/recording_transcriber_widget.py:145
#: buzz/widgets/recording_transcriber_widget.py:151
msgid "Microphone:"
msgstr "Mikrofons:"
#: buzz/widgets/recording_transcriber_widget.py:397
#: buzz/widgets/recording_transcriber_widget.py:403
msgid "An error occurred while starting a new recording:"
msgstr "Sākot jaunu ierakstu notikusi kļūda:"
#: buzz/widgets/recording_transcriber_widget.py:401
#: buzz/widgets/recording_transcriber_widget.py:407
msgid ""
"Please check your audio devices or check the application logs for more "
"information."
@ -412,6 +403,11 @@ msgstr "Laiks"
msgid "Export"
msgstr "Eksportēt"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:154
#: buzz/transcriber/transcriber.py:24
msgid "Translate"
msgstr "Tulkot"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
msgid "API Key Required"
msgstr "API atslēgas kļūda"
@ -449,6 +445,14 @@ msgstr "Lai piešķirtu nepieciešamās atļaujas izpildiet šīs komandas"
msgid "Close"
msgstr "Aizvērt"
#: buzz/widgets/model_download_progress_dialog.py:36
msgid "Downloading model"
msgstr "Lejupielādē modeli"
#: buzz/widgets/model_download_progress_dialog.py:37
msgid "remaining"
msgstr "atlicis"
#: buzz/widgets/menu_bar.py:38
msgid "Import File..."
msgstr "Importēt failu..."
@ -485,6 +489,10 @@ msgstr "Izvēlieties audio failu"
msgid "Unable to save OpenAI API key to keyring"
msgstr "Neizdevās saglabāt OpenAI API atslēgu atslēgu saišķī"
#: buzz/transcriber/transcriber.py:25
msgid "Tanscribe"
msgstr "Atpazīt"
#: buzz/settings/shortcut.py:17
msgid "Open Record Window"
msgstr "Atvērt ieraksta logu"

View file

@ -76,10 +76,6 @@ class ModelType(enum.Enum):
ModelType.FASTER_WHISPER,
)
def supports_recording(self):
# Return True for every model type except OPEN_AI_WHISPER_API, because
# live transcription with the OpenAI Whisper API is not supported.
return self != ModelType.OPEN_AI_WHISPER_API
def is_available(self):
if (
# Hide Whisper.cpp option if whisper.dll did not load correctly.

View file

@ -1,17 +1,22 @@
import datetime
import logging
import sys
import os
import wave
import tempfile
import threading
from typing import Optional
import numpy as np
import sounddevice
from PyQt6.QtCore import QObject, pyqtSignal
from sounddevice import PortAudioError
from openai import OpenAI
from PyQt6.QtCore import QObject, pyqtSignal
from buzz import transformers_whisper, whisper_audio
from buzz.model_loader import ModelType
from buzz.transcriber.transcriber import TranscriptionOptions
from buzz.settings.settings import Settings
from buzz.transcriber.transcriber import TranscriptionOptions, Task
from buzz.transcriber.whisper_cpp import WhisperCpp, whisper_cpp_params
from buzz.transformers_whisper import TransformersWhisper
@ -48,6 +53,7 @@ class RecordingTranscriber(QObject):
self.queue = np.ndarray([], dtype=np.float32)
self.mutex = threading.Lock()
self.sounddevice = sounddevice
self.openai_client = None
def start(self):
model_path = self.model_path
@ -59,6 +65,15 @@ class RecordingTranscriber(QObject):
model = WhisperCpp(model_path)
elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
model = faster_whisper.WhisperModel(model_path)
elif self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API:
settings = Settings()
custom_openai_base_url = settings.value(
key=Settings.Key.CUSTOM_OPENAI_BASE_URL, default_value=""
)
self.openai_client = OpenAI(
api_key=self.transcription_options.openai_access_token,
base_url=custom_openai_base_url if custom_openai_base_url else None
)
else: # ModelType.HUGGING_FACE
model = transformers_whisper.load_model(model_path)
@ -135,8 +150,10 @@ class RecordingTranscriber(QObject):
word_timestamps=self.transcription_options.word_level_timings,
)
result = {"text": " ".join([segment.text for segment in whisper_segments])}
else: # ModelType.HUGGING_FACE
elif (
self.transcription_options.model.model_type
== ModelType.HUGGING_FACE
):
assert isinstance(model, TransformersWhisper)
result = model.transcribe(
audio=samples,
@ -145,6 +162,44 @@ class RecordingTranscriber(QObject):
else "en",
task=self.transcription_options.task.value,
)
else: # OPEN_AI_WHISPER_API
assert self.openai_client is not None
# scale samples to 16-bit PCM
pcm_data = (samples * 32767).astype(np.int16).tobytes()
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
temp_filename = temp_file.name
with wave.open(temp_filename, 'wb') as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(self.sample_rate)
wf.writeframes(pcm_data)
with open(temp_filename, 'rb') as temp_file:
options = {
"model": "whisper-1",
"file": temp_file,
"response_format": "verbose_json",
"prompt": self.transcription_options.initial_prompt,
}
try:
transcript = (
self.openai_client.audio.transcriptions.create(
**options,
language=self.transcription_options.language,
)
if self.transcription_options.task == Task.TRANSCRIBE
else self.openai_client.audio.translations.create(**options)
)
result = {"text": " ".join(
[segment["text"] for segment in transcript.model_extra["segments"]])}
except Exception as e:
result = {"text": f"Error: {str(e)}"}
os.unlink(temp_filename)
next_text: str = result.get("text")

View file

@ -33,11 +33,13 @@ class ModelDownloadProgressDialog(QProgressDialog):
self.setCancelButton(cancel_button)
def update_label_text(self, fraction_completed: float):
# Rebuild the dialog label as "<Downloading model> (NN%[, <time> <remaining>])"
# and apply it with setLabelText.
# NOTE(review): this span comes from a rendered diff, so the pre-change and
# post-change variants of the label_text statements both appear below.
label_text = f"{_('Downloading model')} ({fraction_completed:.0%}"
downloading_text = _("Downloading model")
remaining_text = _("remaining")
label_text = f"{downloading_text} ({fraction_completed:.0%}"
# Only estimate the remaining time once some progress has been reported;
# this also keeps the division below from dividing by zero.
if fraction_completed > 0:
time_spent = (datetime.now() - self.start_time).total_seconds()
# Linear extrapolation: projected total duration minus time already spent.
time_left = (time_spent / fraction_completed) - time_spent
label_text += f", {humanize.naturaldelta(time_left)} {_('remaining')}"
label_text += f", {humanize.naturaldelta(time_left)} {remaining_text}"
label_text += ")"
self.setLabelText(label_text)

View file

@ -17,6 +17,7 @@ from buzz.model_loader import (
TranscriptionModel,
ModelType,
)
from buzz.store.keyring_store import get_password, Key
from buzz.recording import RecordingAmplitudeListener
from buzz.settings.settings import Settings
from buzz.transcriber.recording_transcriber import RecordingTranscriber
@ -78,7 +79,7 @@ class RecordingTranscriberWidget(QWidget):
model_types = [
model_type
for model_type in ModelType
if model_type.is_available() and model_type.supports_recording()
if model_type.is_available()
]
default_model: Optional[TranscriptionModel] = None
if len(model_types) > 0:
@ -92,6 +93,10 @@ class RecordingTranscriberWidget(QWidget):
if selected_model is None or selected_model.model_type not in model_types:
selected_model = default_model
openai_access_token = ""
if selected_model.model_type == ModelType.OPEN_AI_WHISPER_API:
openai_access_token = get_password(key=Key.OPENAI_API_KEY)
self.transcription_options = TranscriptionOptions(
model=selected_model,
task=self.settings.value(
@ -99,6 +104,7 @@ class RecordingTranscriberWidget(QWidget):
default_value=Task.TRANSCRIBE,
),
language=default_language if default_language != "" else None,
openai_access_token=openai_access_token,
initial_prompt=self.settings.value(
key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value=""
),

View file

@ -86,7 +86,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
self.form_layout.addRow(_("Model:"), self.model_type_combo_box)
self.form_layout.addRow("", self.whisper_model_size_combo_box)
self.form_layout.addRow("", self.hugging_face_search_line_edit)
self.form_layout.addRow(_("Access Token:"), self.openai_access_token_edit)
self.form_layout.addRow(_("Api Key:"), self.openai_access_token_edit)
self.form_layout.addRow(_("Task:"), self.tasks_combo_box)
self.form_layout.addRow(_("Language:"), self.languages_combo_box)