mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-26 11:40:09 +02:00
Add support for Hugging Face models (#264)
This commit is contained in:
parent
82bdd30fb8
commit
3dceb11c4d
|
@ -7,4 +7,4 @@ omit =
|
|||
directory = coverage/html
|
||||
|
||||
[report]
|
||||
fail_under = 75
|
||||
fail_under = 78
|
||||
|
|
2
Makefile
2
Makefile
|
@ -43,7 +43,7 @@ clean:
|
|||
rm -rf dist/* || true
|
||||
|
||||
test: buzz/whisper_cpp.py
|
||||
pytest --cov=buzz --cov-report=xml --cov-report=html
|
||||
pytest -vv --cov=buzz --cov-report=xml --cov-report=html
|
||||
|
||||
dist/Buzz dist/Buzz.app: buzz/whisper_cpp.py
|
||||
pyinstaller --noconfirm Buzz.spec
|
||||
|
|
|
@ -23,7 +23,7 @@ class TasksCache:
|
|||
return pickle.load(file)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
except pickle.UnpicklingError: # delete corrupted cache
|
||||
except (pickle.UnpicklingError, AttributeError): # delete corrupted cache
|
||||
os.remove(self.file_path)
|
||||
return []
|
||||
|
||||
|
|
248
buzz/gui.py
248
buzz/gui.py
|
@ -1,4 +1,7 @@
|
|||
import dataclasses
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
|
@ -10,27 +13,28 @@ import humanize
|
|||
import sounddevice
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (QDateTime, QObject, Qt, QThread,
|
||||
QTimer, QUrl, pyqtSignal, QModelIndex, QSize)
|
||||
QTimer, QUrl, pyqtSignal, QModelIndex, QSize, QPoint,
|
||||
QUrlQuery, QMetaObject, QEvent)
|
||||
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
|
||||
QKeySequence, QPixmap, QTextCursor, QValidator)
|
||||
QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent)
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
|
||||
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
|
||||
QDialogButtonBox, QFileDialog, QLabel, QLineEdit,
|
||||
QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QMenu,
|
||||
QWidget, QGroupBox, QToolBar, QTableWidget, QMenuBar, QFormLayout, QTableWidgetItem,
|
||||
QHeaderView, QAbstractItemView)
|
||||
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem)
|
||||
from requests import get
|
||||
from whisper import tokenizer
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
|
||||
from .__version__ import VERSION
|
||||
from .model_loader import ModelLoader
|
||||
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
|
||||
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
|
||||
RecordingTranscriber, Task,
|
||||
WhisperCppFileTranscriber, WhisperFileTranscriber,
|
||||
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
|
||||
Model, FileTranscriberQueueWorker, FileTranscriptionTask)
|
||||
FileTranscriberQueueWorker, FileTranscriptionTask)
|
||||
|
||||
APP_NAME = 'Buzz'
|
||||
|
||||
|
@ -99,7 +103,7 @@ class AudioDevicesComboBox(QComboBox):
|
|||
|
||||
class LanguagesComboBox(QComboBox):
|
||||
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
|
||||
# language is a languge key from whisper.tokenizer.LANGUAGES or '' for "detect language"
|
||||
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
|
||||
languageChanged = pyqtSignal(str)
|
||||
|
||||
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
|
||||
|
@ -136,20 +140,6 @@ class TasksComboBox(QComboBox):
|
|||
self.taskChanged.emit(self.tasks[index])
|
||||
|
||||
|
||||
class ModelComboBox(QComboBox):
|
||||
model_changed = pyqtSignal(Model)
|
||||
|
||||
def __init__(self, default_model: Model, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(parent, *args)
|
||||
self.models = [model for model in Model]
|
||||
self.addItems([model.value for model in self.models])
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
self.setCurrentText(default_model.value)
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
self.model_changed.emit(self.models[index])
|
||||
|
||||
|
||||
class TextDisplayBox(QPlainTextEdit):
|
||||
"""TextDisplayBox is a read-only textbox"""
|
||||
|
||||
|
@ -202,6 +192,7 @@ class DownloadModelProgressDialog(QProgressDialog):
|
|||
'Cancel', 0, 100, parent, *args)
|
||||
self.setWindowModality(Qt.WindowModality.ApplicationModal)
|
||||
self.start_time = datetime.now()
|
||||
self.setFixedSize(self.size())
|
||||
|
||||
def set_fraction_completed(self, fraction_completed: float) -> None:
|
||||
self.setValue(int(fraction_completed * self.maximum()))
|
||||
|
@ -280,7 +271,7 @@ class TimerLabel(QLabel):
|
|||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = f'Unable to load the Whisper model: {error}. Please retry or check the application logs for more ' \
|
||||
message = f"An error occurred while loading the Whisper model: {error}{'' if error.endswith('.') else '.'}" \
|
||||
f'information. '
|
||||
QMessageBox.critical(parent, '', message)
|
||||
|
||||
|
@ -302,7 +293,6 @@ class FileTranscriberWidget(QWidget):
|
|||
super().__init__(parent, flags)
|
||||
|
||||
self.setWindowTitle(file_paths_as_title(file_paths))
|
||||
self.setFixedSize(420, 270)
|
||||
|
||||
self.file_paths = file_paths
|
||||
self.transcription_options = TranscriptionOptions()
|
||||
|
@ -332,15 +322,19 @@ class FileTranscriberWidget(QWidget):
|
|||
layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
self.transcription_options = transcription_options
|
||||
self.word_level_timings_checkbox.setDisabled(
|
||||
self.transcription_options.model.model_type == ModelType.HUGGING_FACE)
|
||||
|
||||
def on_click_run(self):
|
||||
self.run_button.setDisabled(True)
|
||||
|
||||
self.transcriber_thread = QThread()
|
||||
self.model_loader = ModelLoader(model=self.transcription_options.model)
|
||||
self.model_loader.moveToThread(self.transcriber_thread)
|
||||
|
||||
self.transcriber_thread.started.connect(self.model_loader.run)
|
||||
self.model_loader.finished.connect(
|
||||
|
@ -378,6 +372,7 @@ class FileTranscriberWidget(QWidget):
|
|||
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
show_model_download_error_dialog(self, error)
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
|
@ -391,6 +386,7 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
def reset_model_download(self):
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
def on_word_level_timings_changed(self, value: int):
|
||||
|
@ -477,9 +473,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
layout = QVBoxLayout(self)
|
||||
|
||||
self.setWindowTitle('Live Recording')
|
||||
self.setFixedSize(400, 520)
|
||||
|
||||
self.transcription_options = TranscriptionOptions(model=Model.WHISPER_CPP_TINY)
|
||||
self.transcription_options = TranscriptionOptions(model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY))
|
||||
|
||||
self.audio_devices_combo_box = AudioDevicesComboBox(self)
|
||||
self.audio_devices_combo_box.device_changed.connect(
|
||||
|
@ -514,6 +510,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
layout.addWidget(self.text_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def closeEvent(self, event: QCloseEvent) -> None:
|
||||
self.stop_recording()
|
||||
|
@ -534,8 +531,8 @@ class RecordingTranscriberWidget(QWidget):
|
|||
def start_recording(self):
|
||||
self.record_button.setDisabled(True)
|
||||
|
||||
use_whisper_cpp = self.transcription_options.model.is_whisper_cpp(
|
||||
) and self.transcription_options.language is not None
|
||||
use_whisper_cpp = self.transcription_options.model.model_type == ModelType.WHISPER_CPP and \
|
||||
self.transcription_options.language is not None
|
||||
|
||||
def start_recording_transcription(model_path: str):
|
||||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
|
@ -593,6 +590,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
show_model_download_error_dialog(self, error)
|
||||
self.stop_recording()
|
||||
self.record_button.force_stop()
|
||||
|
@ -620,6 +618,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
def reset_model_download(self):
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.close()
|
||||
self.model_download_progress_dialog = None
|
||||
|
||||
|
||||
|
@ -964,6 +963,131 @@ class MainWindow(QMainWindow):
|
|||
super().closeEvent(event)
|
||||
|
||||
|
||||
class LineEdit(QLineEdit):
|
||||
def __init__(self, default_text: str = '', parent: Optional[QWidget] = None):
|
||||
super().__init__(default_text, parent)
|
||||
if platform.system() == 'Darwin':
|
||||
self.setStyleSheet('QLineEdit { padding: 4px }')
|
||||
|
||||
|
||||
# Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py
|
||||
class HuggingFaceSearchLineEdit(LineEdit):
|
||||
model_selected = pyqtSignal(str)
|
||||
popup: QListWidget
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__('', parent)
|
||||
|
||||
self.setMinimumWidth(150)
|
||||
self.setPlaceholderText('openai/whisper-tiny')
|
||||
|
||||
self.timer = QTimer(self)
|
||||
self.timer.setSingleShot(True)
|
||||
self.timer.setInterval(250)
|
||||
self.timer.timeout.connect(self.fetch_models)
|
||||
|
||||
# Restart debounce timer each time editor text changes
|
||||
self.textEdited.connect(self.timer.start)
|
||||
self.textEdited.connect(self.on_text_edited)
|
||||
|
||||
self.network_manager = QNetworkAccessManager(self)
|
||||
self.network_manager.finished.connect(self.on_request_response)
|
||||
|
||||
self.popup = QListWidget()
|
||||
self.popup.setWindowFlags(Qt.WindowType.Popup)
|
||||
self.popup.setFocusProxy(self)
|
||||
self.popup.setMouseTracking(True)
|
||||
self.popup.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
|
||||
self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
|
||||
self.popup.installEventFilter(self)
|
||||
self.popup.itemClicked.connect(self.on_select_item)
|
||||
|
||||
def on_text_edited(self, text: str):
|
||||
self.model_selected.emit(text)
|
||||
|
||||
def on_select_item(self):
|
||||
self.popup.hide()
|
||||
self.setFocus()
|
||||
|
||||
item = self.popup.currentItem()
|
||||
self.setText(item.text())
|
||||
QMetaObject.invokeMethod(self, 'returnPressed')
|
||||
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
|
||||
|
||||
def fetch_models(self):
|
||||
text = self.text()
|
||||
if len(text) < 3:
|
||||
return
|
||||
|
||||
url = QUrl("https://huggingface.co/api/models")
|
||||
|
||||
query = QUrlQuery()
|
||||
query.addQueryItem("filter", "whisper")
|
||||
query.addQueryItem("search", text)
|
||||
|
||||
url.setQuery(query)
|
||||
|
||||
return self.network_manager.get(QNetworkRequest(url))
|
||||
|
||||
def on_popup_selected(self):
|
||||
self.timer.stop()
|
||||
|
||||
def on_request_response(self, network_reply: QNetworkReply):
|
||||
if network_reply.error() != QNetworkReply.NetworkError.NoError:
|
||||
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
|
||||
return
|
||||
|
||||
models = json.loads(network_reply.readAll().data())
|
||||
|
||||
self.popup.setUpdatesEnabled(False)
|
||||
self.popup.clear()
|
||||
|
||||
for model in models:
|
||||
model_id = model.get('id')
|
||||
|
||||
item = QListWidgetItem(self.popup)
|
||||
item.setText(model_id)
|
||||
item.setData(Qt.ItemDataRole.UserRole, model_id)
|
||||
|
||||
self.popup.setCurrentItem(self.popup.item(0))
|
||||
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
|
||||
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
|
||||
self.popup.setUpdatesEnabled(True)
|
||||
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
|
||||
self.popup.setFocus()
|
||||
self.popup.show()
|
||||
|
||||
def eventFilter(self, target: QObject, event: QEvent):
|
||||
if hasattr(self, 'popup') is False or target != self.popup:
|
||||
return False
|
||||
|
||||
if event.type() == QEvent.Type.MouseButtonPress:
|
||||
self.popup.hide()
|
||||
self.setFocus()
|
||||
return True
|
||||
|
||||
if isinstance(event, QKeyEvent):
|
||||
key = event.key()
|
||||
if key in [Qt.Key.Key_Enter, Qt.Key.Key_Return]:
|
||||
self.on_select_item()
|
||||
return True
|
||||
|
||||
if key == Qt.Key.Key_Escape:
|
||||
self.setFocus()
|
||||
self.popup.hide()
|
||||
return True
|
||||
|
||||
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
|
||||
Qt.Key.Key_PageDown]:
|
||||
return False
|
||||
|
||||
self.setFocus()
|
||||
self.event(event)
|
||||
self.popup.hide()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TranscriptionOptionsGroupBox(QGroupBox):
|
||||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
@ -972,7 +1096,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
super().__init__(title='', parent=parent)
|
||||
self.transcription_options = default_transcription_options
|
||||
|
||||
layout = QFormLayout(self)
|
||||
self.form_layout = QFormLayout(self)
|
||||
|
||||
self.tasks_combo_box = TasksComboBox(
|
||||
default_task=self.transcription_options.task,
|
||||
|
@ -985,30 +1109,41 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.languages_combo_box.languageChanged.connect(
|
||||
self.on_language_changed)
|
||||
|
||||
self.model_combo_box = ModelComboBox(
|
||||
default_model=self.transcription_options.model,
|
||||
parent=self)
|
||||
self.model_combo_box.model_changed.connect(self.on_model_changed)
|
||||
|
||||
self.advanced_settings_button = AdvancedSettingsButton(self)
|
||||
self.advanced_settings_button.clicked.connect(
|
||||
self.open_advanced_settings)
|
||||
|
||||
layout.addRow('Task:', self.tasks_combo_box)
|
||||
layout.addRow('Language:', self.languages_combo_box)
|
||||
layout.addRow('Model:', self.model_combo_box)
|
||||
layout.addRow('', self.advanced_settings_button)
|
||||
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
|
||||
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
|
||||
|
||||
self.setLayout(layout)
|
||||
model_type_combo_box = QComboBox(self)
|
||||
model_type_combo_box.addItems([model_type.value for model_type in ModelType])
|
||||
model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
|
||||
model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)
|
||||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
|
||||
if default_transcription_options.model.whisper_model_size is not None:
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
default_transcription_options.model.whisper_model_size.value.title())
|
||||
self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed)
|
||||
|
||||
self.form_layout.addRow('Task:', self.tasks_combo_box)
|
||||
self.form_layout.addRow('Language:', self.languages_combo_box)
|
||||
self.form_layout.addRow('Model:', model_type_combo_box)
|
||||
self.form_layout.addRow('', self.whisper_model_size_combo_box)
|
||||
self.form_layout.addRow('', self.hugging_face_search_line_edit)
|
||||
|
||||
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, False)
|
||||
|
||||
self.form_layout.addRow('', self.advanced_settings_button)
|
||||
|
||||
self.setLayout(self.form_layout)
|
||||
|
||||
def on_language_changed(self, language: str):
|
||||
self.transcription_options.language = language
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_model_changed(self, model: Model):
|
||||
self.transcription_options.model = model
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_task_changed(self, task: Task):
|
||||
self.transcription_options.task = task
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
@ -1032,6 +1167,23 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.transcription_options = transcription_options
|
||||
self.transcription_options_changed.emit(transcription_options)
|
||||
|
||||
def on_model_type_changed(self, text: str):
|
||||
model_type = ModelType(text)
|
||||
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
|
||||
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
|
||||
(model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP))
|
||||
self.transcription_options.model_type = model_type
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_whisper_model_size_changed(self, text: str):
|
||||
model_size = WhisperModelSize(text.lower())
|
||||
self.transcription_options.model.whisper_model_size = model_size
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_hugging_face_model_changed(self, model: str):
|
||||
self.transcription_options.hugging_face_model = model
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
|
||||
class MenuBar(QMenuBar):
|
||||
import_action_triggered = pyqtSignal()
|
||||
|
@ -1080,7 +1232,6 @@ class AdvancedSettingsDialog(QDialog):
|
|||
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
self.setFixedSize(400, 180)
|
||||
self.setWindowTitle('Advanced Settings')
|
||||
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
|
@ -1091,27 +1242,28 @@ class AdvancedSettingsDialog(QDialog):
|
|||
|
||||
default_temperature_text = ', '.join(
|
||||
[str(temp) for temp in transcription_options.temperature])
|
||||
self.temperature_line_edit = QLineEdit(default_temperature_text, self)
|
||||
self.temperature_line_edit = LineEdit(default_temperature_text, self)
|
||||
self.temperature_line_edit.setPlaceholderText(
|
||||
'Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')
|
||||
self.temperature_line_edit.setMinimumWidth(170)
|
||||
self.temperature_line_edit.textChanged.connect(
|
||||
self.on_temperature_changed)
|
||||
self.temperature_line_edit.setValidator(TemperatureValidator(self))
|
||||
self.temperature_line_edit.setDisabled(
|
||||
transcription_options.model.is_whisper_cpp())
|
||||
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
|
||||
|
||||
self.initial_prompt_text_edit = QPlainTextEdit(
|
||||
transcription_options.initial_prompt, self)
|
||||
self.initial_prompt_text_edit.textChanged.connect(
|
||||
self.on_initial_prompt_changed)
|
||||
self.initial_prompt_text_edit.setDisabled(
|
||||
transcription_options.model.is_whisper_cpp())
|
||||
self.initial_prompt_text_edit.setEnabled(
|
||||
transcription_options.model.model_type == ModelType.WHISPER)
|
||||
|
||||
layout.addRow('Temperature:', self.temperature_line_edit)
|
||||
layout.addRow('Initial Prompt:', self.initial_prompt_text_edit)
|
||||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setFixedSize(self.sizeHint())
|
||||
|
||||
def on_temperature_changed(self, text: str):
|
||||
try:
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import enum
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
@ -9,9 +11,31 @@ import whisper
|
|||
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
from buzz.transcriber import Model
|
||||
from buzz import transformers_whisper
|
||||
|
||||
MODELS_SHA256 = {
|
||||
|
||||
class WhisperModelSize(enum.Enum):
|
||||
TINY = 'tiny'
|
||||
BASE = 'base'
|
||||
SMALL = 'small'
|
||||
MEDIUM = 'medium'
|
||||
LARGE = 'large'
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
WHISPER = 'Whisper'
|
||||
WHISPER_CPP = 'Whisper.cpp'
|
||||
HUGGING_FACE = 'Hugging Face'
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TranscriptionModel:
|
||||
model_type: ModelType = ModelType.WHISPER
|
||||
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
|
||||
hugging_face_model_id: Optional[str] = None
|
||||
|
||||
|
||||
WHISPER_CPP_MODELS_SHA256 = {
|
||||
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
|
||||
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
|
||||
'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
|
||||
|
@ -20,53 +44,85 @@ MODELS_SHA256 = {
|
|||
}
|
||||
|
||||
|
||||
def get_hugging_face_dataset_file_url(author: str, repository_name: str, filename: str):
|
||||
return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}'
|
||||
|
||||
|
||||
class ModelLoader(QObject):
|
||||
progress = pyqtSignal(tuple) # (current, total)
|
||||
finished = pyqtSignal(str)
|
||||
error = pyqtSignal(str)
|
||||
stopped = False
|
||||
|
||||
def __init__(self, model: Model, parent: Optional['QObject'] = None) -> None:
|
||||
def __init__(self, model: TranscriptionModel, parent: Optional['QObject'] = None) -> None:
|
||||
super().__init__(parent)
|
||||
self.name = model.model_name()
|
||||
self.use_whisper_cpp = model.is_whisper_cpp()
|
||||
self.model_type = model.model_type
|
||||
self.whisper_model_size = model.whisper_model_size
|
||||
self.hugging_face_model_id = model.hugging_face_model_id
|
||||
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
if self.model_type == ModelType.WHISPER_CPP:
|
||||
root_dir = user_cache_dir('Buzz')
|
||||
model_name = self.whisper_model_size.value
|
||||
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
|
||||
filename=f'ggml-{model_name}.bin')
|
||||
file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin')
|
||||
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
|
||||
self.download_model(url, file_path, expected_sha256)
|
||||
return
|
||||
|
||||
if self.model_type == ModelType.WHISPER:
|
||||
root_dir = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
model_name = self.whisper_model_size.value
|
||||
url = whisper._MODELS[model_name]
|
||||
file_path = os.path.join(root_dir, os.path.basename(url))
|
||||
expected_sha256 = url.split('/')[-2]
|
||||
self.download_model(url, file_path, expected_sha256)
|
||||
return
|
||||
|
||||
if self.model_type == ModelType.HUGGING_FACE:
|
||||
self.progress.emit((0, 100))
|
||||
|
||||
try:
|
||||
# Loads the model from cache or download if not in cache
|
||||
transformers_whisper.load_model(self.hugging_face_model_id)
|
||||
except (FileNotFoundError, EnvironmentError) as exception:
|
||||
self.error.emit(f'{exception}')
|
||||
return
|
||||
|
||||
self.progress.emit((100, 100))
|
||||
self.finished.emit(self.hugging_face_model_id)
|
||||
return
|
||||
|
||||
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
|
||||
try:
|
||||
if self.use_whisper_cpp:
|
||||
root = user_cache_dir('Buzz')
|
||||
url = f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{self.name}.bin'
|
||||
model_path = os.path.join(root, f'ggml-model-whisper-{self.name}.bin')
|
||||
else:
|
||||
root = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
url = whisper._MODELS[self.name]
|
||||
model_path = os.path.join(root, os.path.basename(url))
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
if os.path.exists(model_path) and not os.path.isfile(model_path):
|
||||
if os.path.exists(file_path) and not os.path.isfile(file_path):
|
||||
raise RuntimeError(
|
||||
f"{model_path} exists and is not a regular file")
|
||||
f"{file_path} exists and is not a regular file")
|
||||
|
||||
expected_sha256 = MODELS_SHA256[self.name] if self.use_whisper_cpp else url.split(
|
||||
"/")[-2]
|
||||
if os.path.isfile(model_path):
|
||||
model_bytes = open(model_path, "rb").read()
|
||||
if os.path.isfile(file_path):
|
||||
if expected_sha256 is None:
|
||||
self.finished.emit(file_path)
|
||||
return
|
||||
|
||||
model_bytes = open(file_path, "rb").read()
|
||||
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
|
||||
if model_sha256 == expected_sha256:
|
||||
self.finished.emit(model_path)
|
||||
self.finished.emit(file_path)
|
||||
return
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
# Downloads the model using the requests module instead of urllib to
|
||||
# use the certs from certifi when the app is running in frozen mode
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output:
|
||||
source.raise_for_status()
|
||||
total_size = float(source.headers.get('Content-Length', 0))
|
||||
current = 0.0
|
||||
|
@ -78,12 +134,14 @@ class ModelLoader(QObject):
|
|||
current += len(chunk)
|
||||
self.progress.emit((current, total_size))
|
||||
|
||||
model_bytes = open(model_path, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
|
||||
if expected_sha256 is not None:
|
||||
model_bytes = open(file_path, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
|
||||
"model.")
|
||||
|
||||
self.finished.emit(model_path)
|
||||
self.finished.emit(file_path)
|
||||
except RuntimeError as exc:
|
||||
self.error.emit(str(exc))
|
||||
logging.exception('')
|
||||
|
|
|
@ -15,7 +15,6 @@ from dataclasses import dataclass, field
|
|||
from multiprocessing.connection import Connection
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
import typing
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
|
@ -25,7 +24,9 @@ import whisper
|
|||
from PyQt6.QtCore import QObject, QProcess, pyqtSignal, pyqtSlot, QThread
|
||||
from sounddevice import PortAudioError
|
||||
|
||||
from . import transformers_whisper
|
||||
from .conn import pipe_stderr
|
||||
from .model_loader import TranscriptionModel, ModelType
|
||||
|
||||
# Catch exception from whisper.dll not getting loaded.
|
||||
# TODO: Remove flag and try-except when issue with loading
|
||||
|
@ -53,32 +54,11 @@ class Segment:
|
|||
text: str
|
||||
|
||||
|
||||
class Model(enum.Enum):
|
||||
WHISPER_TINY = 'Whisper - Tiny'
|
||||
WHISPER_BASE = 'Whisper - Base'
|
||||
WHISPER_SMALL = 'Whisper - Small'
|
||||
WHISPER_MEDIUM = 'Whisper - Medium'
|
||||
WHISPER_LARGE = 'Whisper - Large'
|
||||
WHISPER_CPP_TINY = 'Whisper.cpp - Tiny'
|
||||
WHISPER_CPP_BASE = 'Whisper.cpp - Base'
|
||||
WHISPER_CPP_SMALL = 'Whisper.cpp - Small'
|
||||
WHISPER_CPP_MEDIUM = 'Whisper.cpp - Medium'
|
||||
WHISPER_CPP_LARGE = 'Whisper.cpp - Large'
|
||||
|
||||
def is_whisper_cpp(self) -> bool:
|
||||
model_type, _ = self.value.split(' - ')
|
||||
return model_type == 'Whisper.cpp'
|
||||
|
||||
def model_name(self) -> str:
|
||||
_, model_name = self.value.split(' - ')
|
||||
return model_name.lower()
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TranscriptionOptions:
|
||||
language: Optional[str] = None
|
||||
task: Task = Task.TRANSCRIBE
|
||||
model: Model = Model.WHISPER_TINY
|
||||
model: TranscriptionModel = TranscriptionModel()
|
||||
word_level_timings: bool = False
|
||||
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
|
||||
initial_prompt: str = ''
|
||||
|
@ -373,7 +353,7 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
current_process: multiprocessing.Process
|
||||
progress = pyqtSignal(tuple) # (current, total)
|
||||
completed = pyqtSignal(tuple) # (exit_code: int, segments: List[Segment])
|
||||
completed = pyqtSignal(list) # List[Segment]
|
||||
error = pyqtSignal(str)
|
||||
running = False
|
||||
read_line_thread: Optional[Thread] = None
|
||||
|
@ -390,6 +370,8 @@ class WhisperFileTranscriber(QObject):
|
|||
self.temperature = task.transcription_options.temperature
|
||||
self.initial_prompt = task.transcription_options.initial_prompt
|
||||
self.model_path = task.model_path
|
||||
self.transcription_options = task.transcription_options
|
||||
self.transcription_task = task
|
||||
self.segments = []
|
||||
|
||||
@pyqtSlot()
|
||||
|
@ -406,13 +388,8 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=transcribe_whisper,
|
||||
args=(
|
||||
send_pipe, model_path, self.file_path,
|
||||
self.language, self.task, self.word_level_timings,
|
||||
self.temperature, self.initial_prompt
|
||||
))
|
||||
self.current_process = multiprocessing.Process(target=transcribe_whisper,
|
||||
args=(send_pipe, self.transcription_task))
|
||||
self.current_process.start()
|
||||
|
||||
self.read_line_thread = Thread(
|
||||
|
@ -430,8 +407,9 @@ class WhisperFileTranscriber(QObject):
|
|||
|
||||
self.read_line_thread.join()
|
||||
|
||||
if self.current_process.exitcode != 0:
|
||||
self.completed.emit((self.current_process.exitcode, []))
|
||||
# TODO: fix error handling when process crashes
|
||||
if self.current_process.exitcode != 0 and self.current_process.exitcode is not None:
|
||||
self.completed.emit([])
|
||||
|
||||
self.running = False
|
||||
|
||||
|
@ -458,7 +436,7 @@ class WhisperFileTranscriber(QObject):
|
|||
) for segment in segments_dict]
|
||||
self.current_process.join()
|
||||
# TODO: move this back to the parent thread
|
||||
self.completed.emit((self.current_process.exitcode, segments))
|
||||
self.completed.emit(segments)
|
||||
else:
|
||||
try:
|
||||
progress = int(line.split('|')[0].strip().strip('%'))
|
||||
|
@ -467,26 +445,30 @@ class WhisperFileTranscriber(QObject):
|
|||
continue
|
||||
|
||||
|
||||
def transcribe_whisper(
|
||||
stderr_conn: Connection, model_path: str, file_path: str,
|
||||
language: Optional[str], task: Task,
|
||||
word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
|
||||
def transcribe_whisper(stderr_conn: Connection, task: FileTranscriptionTask):
|
||||
with pipe_stderr(stderr_conn):
|
||||
model = whisper.load_model(model_path)
|
||||
|
||||
if word_level_timings:
|
||||
stable_whisper.modify_model(model)
|
||||
result = model.transcribe(
|
||||
audio=file_path, language=language,
|
||||
task=task.value, temperature=temperature,
|
||||
initial_prompt=initial_prompt, pbar=True)
|
||||
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
|
||||
model = transformers_whisper.load_model(task.model_path)
|
||||
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
|
||||
result = model.transcribe(audio_path=task.file_path, language=language,
|
||||
task=task.transcription_options.task.value, verbose=False)
|
||||
whisper_segments = result.get('segments')
|
||||
else:
|
||||
result = model.transcribe(
|
||||
audio=file_path, language=language, task=task.value, temperature=temperature,
|
||||
initial_prompt=initial_prompt, verbose=False)
|
||||
|
||||
whisper_segments = stable_whisper.group_word_timestamps(
|
||||
result) if word_level_timings else result.get('segments')
|
||||
model = whisper.load_model(task.model_path)
|
||||
if task.transcription_options.word_level_timings:
|
||||
stable_whisper.modify_model(model)
|
||||
result = model.transcribe(
|
||||
audio=task.file_path, language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value, temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
|
||||
whisper_segments = stable_whisper.group_word_timestamps(result)
|
||||
else:
|
||||
result = model.transcribe(
|
||||
audio=task.file_path, language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt, verbose=False)
|
||||
whisper_segments = result.get('segments')
|
||||
|
||||
segments = [
|
||||
Segment(
|
||||
|
@ -638,7 +620,7 @@ class FileTranscriberQueueWorker(QObject):
|
|||
self.completed.emit()
|
||||
return
|
||||
|
||||
if self.current_task.transcription_options.model.is_whisper_cpp():
|
||||
if self.current_task.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
self.current_transcriber = WhisperCppFileTranscriber(
|
||||
task=self.current_task)
|
||||
else:
|
||||
|
@ -688,10 +670,9 @@ class FileTranscriberQueueWorker(QObject):
|
|||
self.current_task.fraction_completed = progress[0] / progress[1]
|
||||
self.task_updated.emit(self.current_task)
|
||||
|
||||
@pyqtSlot(tuple)
|
||||
def on_task_completed(self, result: Tuple[int, List[Segment]]):
|
||||
@pyqtSlot(list)
|
||||
def on_task_completed(self, segments: List[Segment]):
|
||||
if self.current_task is not None:
|
||||
_, segments = result
|
||||
self.current_task.status = FileTranscriptionTask.Status.COMPLETED
|
||||
self.current_task.segments = segments
|
||||
self.task_updated.emit(self.current_task)
|
||||
|
|
56
buzz/transformers_whisper.py
Normal file
56
buzz/transformers_whisper.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import whisper
|
||||
from tqdm import tqdm
|
||||
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||
|
||||
|
||||
def load_model(model_name_or_path: str):
|
||||
processor = WhisperProcessor.from_pretrained(model_name_or_path)
|
||||
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)
|
||||
return TransformersWhisper(processor, model)
|
||||
|
||||
|
||||
class TransformersWhisper:
|
||||
SAMPLE_RATE = whisper.audio.SAMPLE_RATE
|
||||
N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
|
||||
|
||||
def __init__(self, processor: WhisperProcessor, model: WhisperForConditionalGeneration):
|
||||
self.processor = processor
|
||||
self.model = model
|
||||
|
||||
# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
|
||||
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
|
||||
# https://github.com/huggingface/transformers/pull/20620.
|
||||
def transcribe(self, audio_path: str, language: str, task: str, verbose: Optional[bool] = None):
|
||||
audio: np.ndarray = whisper.load_audio(audio_path, sr=self.SAMPLE_RATE)
|
||||
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(task=task, language=language)
|
||||
|
||||
segments = []
|
||||
all_predicted_ids = []
|
||||
|
||||
num_samples = audio.size
|
||||
seek = 0
|
||||
with tqdm(total=num_samples, unit='samples', disable=verbose is not False) as progress_bar:
|
||||
while seek < num_samples:
|
||||
chunk = audio[seek: seek + self.N_SAMPLES_IN_CHUNK]
|
||||
input_features = self.processor(chunk, return_tensors="pt",
|
||||
sampling_rate=self.SAMPLE_RATE).input_features
|
||||
predicted_ids = self.model.generate(input_features)
|
||||
all_predicted_ids.extend(predicted_ids)
|
||||
text: str = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
||||
if text.strip() != '':
|
||||
segments.append({
|
||||
'start': seek / self.SAMPLE_RATE,
|
||||
'end': min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) / self.SAMPLE_RATE,
|
||||
'text': text
|
||||
})
|
||||
|
||||
progress_bar.update(min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek)
|
||||
seek += self.N_SAMPLES_IN_CHUNK
|
||||
|
||||
return {
|
||||
'text': self.processor.batch_decode(all_predicted_ids, skip_special_tokens=True)[0],
|
||||
'segments': segments
|
||||
}
|
|
@ -6,19 +6,19 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
import sounddevice
|
||||
from PyQt6.QtCore import QSize, Qt
|
||||
from PyQt6.QtGui import QValidator
|
||||
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget
|
||||
from PyQt6.QtGui import QValidator, QKeyEvent
|
||||
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.gui import (AboutDialog, AdvancedSettingsDialog, Application,
|
||||
AudioDevicesComboBox, DownloadModelProgressDialog,
|
||||
FileTranscriberWidget, LanguagesComboBox, MainWindow,
|
||||
ModelComboBox, RecordingTranscriberWidget,
|
||||
RecordingTranscriberWidget,
|
||||
TemperatureValidator, TextDisplayBox,
|
||||
TranscriptionTasksTableWidget, TranscriptionViewerWidget)
|
||||
TranscriptionTasksTableWidget, TranscriptionViewerWidget, HuggingFaceSearchLineEdit)
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
|
||||
Model, Segment, TranscriptionOptions)
|
||||
Segment, TranscriptionOptions)
|
||||
|
||||
|
||||
class TestApplication:
|
||||
|
@ -48,20 +48,6 @@ class TestLanguagesComboBox:
|
|||
assert languages_combo_box.currentText() == 'Detect Language'
|
||||
|
||||
|
||||
class TestModelComboBox:
|
||||
model_combo_box = ModelComboBox(
|
||||
default_model=Model.WHISPER_CPP_BASE, parent=None)
|
||||
|
||||
def test_should_show_qualities(self):
|
||||
assert self.model_combo_box.itemText(0) == 'Whisper - Tiny'
|
||||
assert self.model_combo_box.itemText(1) == 'Whisper - Base'
|
||||
assert self.model_combo_box.itemText(2) == 'Whisper - Small'
|
||||
assert self.model_combo_box.itemText(3) == 'Whisper - Medium'
|
||||
|
||||
def test_should_select_default_model(self):
|
||||
assert self.model_combo_box.currentText() == 'Whisper.cpp - Base'
|
||||
|
||||
|
||||
class TestAudioDevicesComboBox:
|
||||
mock_query_devices = [
|
||||
{'name': 'Background Music', 'index': 0, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
|
||||
|
@ -198,10 +184,9 @@ class TestFileTranscriberWidget:
|
|||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], parent=None)
|
||||
|
||||
def test_should_set_window_title_and_size(self, qtbot: QtBot):
|
||||
def test_should_set_window_title(self, qtbot: QtBot):
|
||||
qtbot.addWidget(self.widget)
|
||||
assert self.widget.windowTitle() == 'whisper-french.mp3'
|
||||
assert self.widget.size() == QSize(420, 270)
|
||||
|
||||
def test_should_emit_triggered_event(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
|
@ -217,7 +202,6 @@ class TestFileTranscriberWidget:
|
|||
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
|
||||
0][0]
|
||||
assert transcription_options.language is None
|
||||
assert transcription_options.model == Model.WHISPER_TINY
|
||||
assert file_transcription_options.file_paths == [
|
||||
'testdata/whisper-french.mp3']
|
||||
assert len(model_path) > 0
|
||||
|
@ -232,8 +216,7 @@ class TestAboutDialog:
|
|||
class TestAdvancedSettingsDialog:
|
||||
def test_should_update_advanced_settings(self, qtbot: QtBot):
|
||||
dialog = AdvancedSettingsDialog(
|
||||
transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt',
|
||||
model=Model.WHISPER_CPP_BASE))
|
||||
transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt'))
|
||||
qtbot.add_widget(dialog)
|
||||
|
||||
transcription_options_mock = Mock()
|
||||
|
@ -333,7 +316,50 @@ class TestTranscriptionTasksTableWidget:
|
|||
class TestRecordingTranscriberWidget:
|
||||
widget = RecordingTranscriberWidget()
|
||||
|
||||
def test_should_set_window_title_and_size(self, qtbot: QtBot):
|
||||
def test_should_set_window_title(self, qtbot: QtBot):
|
||||
qtbot.add_widget(self.widget)
|
||||
assert self.widget.windowTitle() == 'Live Recording'
|
||||
assert self.widget.size() == QSize(400, 520)
|
||||
|
||||
|
||||
class TestHuggingFaceSearchLineEdit:
|
||||
def test_should_update_selected_model_on_type(self, qtbot: QtBot):
|
||||
widget = HuggingFaceSearchLineEdit()
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_model_selected = Mock()
|
||||
widget.model_selected.connect(mock_model_selected)
|
||||
|
||||
self._set_text_and_wait_response(qtbot, widget)
|
||||
mock_model_selected.assert_called_with('openai/whisper-tiny')
|
||||
|
||||
def test_should_show_list_of_models(self, qtbot: QtBot):
|
||||
widget = HuggingFaceSearchLineEdit()
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
self._set_text_and_wait_response(qtbot, widget)
|
||||
|
||||
assert widget.popup.count() > 0
|
||||
assert 'openai/whisper-tiny' in widget.popup.item(0).text()
|
||||
|
||||
def test_should_select_model_from_list(self, qtbot: QtBot):
|
||||
widget = HuggingFaceSearchLineEdit()
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_model_selected = Mock()
|
||||
widget.model_selected.connect(mock_model_selected)
|
||||
|
||||
self._set_text_and_wait_response(qtbot, widget)
|
||||
|
||||
# press down arrow and enter to select next item
|
||||
QApplication.sendEvent(widget.popup,
|
||||
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier))
|
||||
QApplication.sendEvent(widget.popup,
|
||||
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Enter, Qt.KeyboardModifier.NoModifier))
|
||||
|
||||
mock_model_selected.assert_called_with('openai/whisper-tiny.en')
|
||||
|
||||
@staticmethod
|
||||
def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit):
|
||||
with qtbot.wait_signal(widget.network_manager.finished, timeout=30 * 1000):
|
||||
widget.setText('openai/whisper-tiny')
|
||||
widget.textEdited.emit('openai/whisper-tiny')
|
||||
|
|
14
tests/model_loader.py
Normal file
14
tests/model_loader.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from buzz.model_loader import ModelLoader, TranscriptionModel
|
||||
|
||||
|
||||
def get_model_path(transcription_model: TranscriptionModel) -> str:
|
||||
model_loader = ModelLoader(model=transcription_model)
|
||||
model_path = ''
|
||||
|
||||
def on_load_model(path: str):
|
||||
nonlocal model_path
|
||||
model_path = path
|
||||
|
||||
model_loader.finished.connect(on_load_model)
|
||||
model_loader.run()
|
||||
return model_path
|
|
@ -8,30 +8,18 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.model_loader import ModelLoader
|
||||
from buzz.transcriber import (FileTranscriberQueueWorker, FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber, Segment, Task,
|
||||
WhisperCpp, WhisperCppFileTranscriber,
|
||||
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
|
||||
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
get_default_output_file_path, to_timestamp,
|
||||
whisper_cpp_params, write_output, TranscriptionOptions, Model)
|
||||
|
||||
|
||||
def get_model_path(model: Model) -> str:
|
||||
model_loader = ModelLoader(model=model)
|
||||
model_path = ''
|
||||
|
||||
def on_load_model(path: str):
|
||||
nonlocal model_path
|
||||
model_path = path
|
||||
|
||||
model_loader.finished.connect(on_load_model)
|
||||
model_loader.run()
|
||||
return model_path
|
||||
whisper_cpp_params, write_output, TranscriptionOptions)
|
||||
from tests.model_loader import get_model_path
|
||||
|
||||
|
||||
class TestRecordingTranscriber:
|
||||
def test_transcriber(self):
|
||||
model_path = get_model_path(Model.WHISPER_CPP_TINY)
|
||||
model_path = get_model_path(transcription_model=TranscriptionModel())
|
||||
transcriber = RecordingTranscriber(
|
||||
model_path=model_path, use_whisper_cpp=True, language='en',
|
||||
task=Task.TRANSCRIBE)
|
||||
|
@ -49,11 +37,15 @@ class TestWhisperCppFileTranscriber:
|
|||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3'])
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings)
|
||||
word_level_timings=word_level_timings,
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY))
|
||||
|
||||
model_path = get_model_path(Model.WHISPER_CPP_TINY)
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
transcriber = WhisperCppFileTranscriber(
|
||||
task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3', transcription_options=transcription_options, file_transcription_options=file_transcription_options, model_path=model_path))
|
||||
task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3',
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options, model_path=model_path))
|
||||
mock_progress = Mock()
|
||||
mock_completed = Mock()
|
||||
transcriber.progress.connect(mock_progress)
|
||||
|
@ -81,44 +73,51 @@ class TestWhisperFileTranscriber:
|
|||
assert srt.endswith('.srt')
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'word_level_timings,expected_segments',
|
||||
'word_level_timings,expected_segments,model,check_progress',
|
||||
[
|
||||
(False, [
|
||||
Segment(
|
||||
0, 6560,
|
||||
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances'),
|
||||
]),
|
||||
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')])
|
||||
(False, [Segment(0, 6560,
|
||||
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')],
|
||||
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
|
||||
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')],
|
||||
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
|
||||
(False, [Segment(0, 8517,
|
||||
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances '
|
||||
'et des apprenances de français.')],
|
||||
TranscriptionModel(model_type=ModelType.HUGGING_FACE,
|
||||
hugging_face_model_id='openai/whisper-tiny'), False)
|
||||
])
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||
model_path = get_model_path(Model.WHISPER_TINY)
|
||||
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment],
|
||||
model: TranscriptionModel, check_progress):
|
||||
mock_progress = Mock()
|
||||
mock_completed = Mock()
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings)
|
||||
word_level_timings=word_level_timings,
|
||||
model=model)
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3'])
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
task=FileTranscriptionTask(transcription_options=transcription_options, file_transcription_options=file_transcription_options, file_path='testdata/whisper-french.mp3', model_path=model_path))
|
||||
task=FileTranscriptionTask(transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
file_path='testdata/whisper-french.mp3', model_path=model_path))
|
||||
transcriber.progress.connect(mock_progress)
|
||||
transcriber.completed.connect(mock_completed)
|
||||
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
|
||||
transcriber.run()
|
||||
|
||||
# Reports progress at 0, 0<progress<100, and 100
|
||||
assert any(
|
||||
[call_args.args[0] == (0, 100) for call_args in mock_progress.call_args_list])
|
||||
assert any(
|
||||
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
|
||||
assert any(
|
||||
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
|
||||
mock_progress.call_args_list])
|
||||
if check_progress:
|
||||
# Reports progress at 0, 0<progress<100, and 100
|
||||
assert any(
|
||||
[call_args.args[0] == (0, 100) for call_args in mock_progress.call_args_list])
|
||||
assert any(
|
||||
[call_args.args[0] == (100, 100) for call_args in mock_progress.call_args_list])
|
||||
assert any(
|
||||
[(0 < call_args.args[0][0] < 100) and (call_args.args[0][1] == 100) for call_args in
|
||||
mock_progress.call_args_list])
|
||||
|
||||
mock_completed.assert_called()
|
||||
exit_code, segments = mock_completed.call_args[0][0]
|
||||
assert exit_code is 0
|
||||
segments = mock_completed.call_args[0][0]
|
||||
for (i, expected_segment) in enumerate(expected_segments):
|
||||
assert segments[i] == expected_segment
|
||||
|
||||
|
@ -128,14 +127,17 @@ class TestWhisperFileTranscriber:
|
|||
if os.path.exists(output_file_path):
|
||||
os.remove(output_file_path)
|
||||
|
||||
model_path = get_model_path(Model.WHISPER_TINY)
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3'])
|
||||
transcription_options = TranscriptionOptions(
|
||||
language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
|
||||
language='fr', task=Task.TRANSCRIBE, word_level_timings=False,
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options, file_transcription_options=file_transcription_options, file_path='testdata/whisper-french.mp3'))
|
||||
task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
file_path='testdata/whisper-french.mp3'))
|
||||
transcriber.run()
|
||||
time.sleep(1)
|
||||
transcriber.stop()
|
||||
|
@ -152,7 +154,9 @@ class TestToTimestamp:
|
|||
|
||||
class TestWhisperCpp:
|
||||
def test_transcribe(self):
|
||||
model_path = get_model_path(Model.WHISPER_CPP_TINY)
|
||||
transcription_options = TranscriptionOptions(
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
|
||||
whisper_cpp = WhisperCpp(model=model_path)
|
||||
params = whisper_cpp_params(
|
||||
|
@ -168,8 +172,8 @@ class TestWhisperCpp:
|
|||
[
|
||||
(OutputFormat.TXT, 'Bien venue dans\n'),
|
||||
(
|
||||
OutputFormat.SRT,
|
||||
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
OutputFormat.SRT,
|
||||
'1\n00:00:00.040 --> 00:00:00.299\nBien\n\n2\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
(OutputFormat.VTT,
|
||||
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
])
|
||||
|
|
10
tests/transformers_whisper_test.py
Normal file
10
tests/transformers_whisper_test.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from buzz.transformers_whisper import load_model
|
||||
|
||||
|
||||
class TestTransformersWhisper:
|
||||
def test_should_transcribe(self):
|
||||
model = load_model('openai/whisper-tiny')
|
||||
result = model.transcribe(
|
||||
audio_path='testdata/whisper-french.mp3', language='fr', task='transcribe')
|
||||
|
||||
assert 'Bienvenue dans Passe' in result['text']
|
Loading…
Reference in a new issue