mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-15 15:15:49 +01:00
Replace models combo box with quality (#69)
This commit is contained in:
parent
5a17cf63fe
commit
ec0aef39b5
1 changed files with 46 additions and 32 deletions
78
gui.py
78
gui.py
|
|
@ -1,4 +1,5 @@
|
|||
import enum
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from datetime import datetime
|
||||
|
|
@ -102,25 +103,25 @@ class TasksComboBox(QComboBox):
|
|||
self.taskChanged.emit(self.tasks[index])
|
||||
|
||||
|
||||
class ModelsComboBox(QComboBox):
|
||||
"""ModelsComboBox displays the list of available Whisper models for selection"""
|
||||
modelNameChanged = pyqtSignal(str)
|
||||
class Quality(enum.Enum):
|
||||
LOW = 'low'
|
||||
MEDIUM = 'medium'
|
||||
HIGH = 'high'
|
||||
|
||||
def __init__(self, default_model_name: str, *args) -> None:
|
||||
|
||||
class QualityComboBox(QComboBox):
|
||||
quality_changed = pyqtSignal(Quality)
|
||||
|
||||
def __init__(self, default_quality: Quality, *args) -> None:
|
||||
super().__init__(*args)
|
||||
self.models = whisper.available_models()
|
||||
self.addItems(map(self.label, self.models))
|
||||
self.qualities = [i for i in Quality]
|
||||
self.addItems(
|
||||
map(lambda quality: quality.value.title(), self.qualities))
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
self.setCurrentText(self.label(default_model_name))
|
||||
self.setCurrentText(default_quality.value.title())
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
self.modelNameChanged.emit(self.models[index])
|
||||
|
||||
def label(self, model_name: str):
|
||||
name, lang = (model_name.split('.') + [None])[:2]
|
||||
if lang:
|
||||
return "%s (%s)" % (name.title(), lang.upper())
|
||||
return name.title()
|
||||
self.quality_changed.emit(self.qualities[index])
|
||||
|
||||
|
||||
class DelaysComboBox(QComboBox):
|
||||
|
|
@ -216,7 +217,6 @@ class DownloadModelProgressDialog(QProgressDialog):
|
|||
|
||||
|
||||
class TranscriberProgressDialog(QProgressDialog):
|
||||
total_size: int
|
||||
short_file_path: str
|
||||
start_time: datetime
|
||||
|
||||
|
|
@ -302,8 +302,16 @@ class TimerLabel(QLabel):
|
|||
seconds_passed // 60, seconds_passed % 60))
|
||||
|
||||
|
||||
def get_model_name(quality: Quality, language: str) -> str:
|
||||
return {
|
||||
Quality.LOW: ('tiny', 'tiny.en'),
|
||||
Quality.MEDIUM: ('base', 'base.en'),
|
||||
Quality.HIGH: ('small', 'small.en'),
|
||||
}[quality][1 if language == 'en' else 0]
|
||||
|
||||
|
||||
class FileTranscriberWidget(QWidget):
|
||||
selected_model_name = 'tiny'
|
||||
selected_quality = Quality.LOW
|
||||
selected_language = 'en'
|
||||
selected_task = Task.TRANSCRIBE
|
||||
progress_dialog: Optional[DownloadModelProgressDialog] = None
|
||||
|
|
@ -317,9 +325,9 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
self.file_path = file_path
|
||||
|
||||
self.models_combo_box = ModelsComboBox(
|
||||
default_model_name=self.selected_model_name)
|
||||
self.models_combo_box.modelNameChanged.connect(self.on_model_changed)
|
||||
self.quality_combo_box = QualityComboBox(
|
||||
default_quality=self.selected_quality)
|
||||
self.quality_combo_box.quality_changed.connect(self.on_quality_changed)
|
||||
|
||||
self.languages_combo_box = LanguagesComboBox(
|
||||
default_language=self.selected_language)
|
||||
|
|
@ -337,7 +345,7 @@ class FileTranscriberWidget(QWidget):
|
|||
grid = (
|
||||
((0, 5, FormLabel('Task:')), (5, 7, self.tasks_combo_box)),
|
||||
((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)),
|
||||
((0, 5, FormLabel('Model:')), (5, 7, self.models_combo_box)),
|
||||
((0, 5, FormLabel('Quality:')), (5, 7, self.quality_combo_box)),
|
||||
((9, 3, self.run_button),)
|
||||
)
|
||||
|
||||
|
|
@ -350,8 +358,8 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
self.transcribe_progress.connect(self.handle_transcribe_progress)
|
||||
|
||||
def on_model_changed(self, model_name: str):
|
||||
self.selected_model_name = model_name
|
||||
def on_quality_changed(self, quality: Quality):
|
||||
self.selected_quality = quality
|
||||
|
||||
def on_language_changed(self, language: str):
|
||||
self.selected_language = language
|
||||
|
|
@ -369,8 +377,11 @@ class FileTranscriberWidget(QWidget):
|
|||
return
|
||||
|
||||
self.run_button.setDisabled(True)
|
||||
model_name = get_model_name(
|
||||
self.selected_quality, self.selected_language)
|
||||
logging.debug(f'Loading model: {model_name}')
|
||||
model = _whisper.load_model(
|
||||
self.selected_model_name, on_download_model_chunk=self.on_download_model_progress)
|
||||
model_name, on_download_model_chunk=self.on_download_model_progress)
|
||||
|
||||
self.file_transcriber = FileTranscriber(
|
||||
model=model, file_path=self.file_path,
|
||||
|
|
@ -412,7 +423,7 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
class RecordingTranscriberWidget(QWidget):
|
||||
current_status = RecordButton.Status.STOPPED
|
||||
selected_model_name = 'tiny'
|
||||
selected_quality = Quality.LOW
|
||||
selected_language = 'en'
|
||||
selected_device_id: Optional[int]
|
||||
selected_delay = 10
|
||||
|
|
@ -424,9 +435,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
layout = QGridLayout()
|
||||
|
||||
self.models_combo_box = ModelsComboBox(
|
||||
default_model_name=self.selected_model_name)
|
||||
self.models_combo_box.modelNameChanged.connect(self.on_model_changed)
|
||||
self.quality_combo_box = QualityComboBox(
|
||||
default_quality=self.selected_quality)
|
||||
self.quality_combo_box.quality_changed.connect(self.on_quality_changed)
|
||||
|
||||
self.languages_combo_box = LanguagesComboBox(
|
||||
default_language=self.selected_language)
|
||||
|
|
@ -453,9 +464,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.text_box = TextDisplayBox()
|
||||
|
||||
grid = (
|
||||
((0, 5, FormLabel('Model:')), (5, 7, self.models_combo_box)),
|
||||
((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)),
|
||||
((0, 5, FormLabel('Task:')), (5, 7, self.tasks_combo_box)),
|
||||
((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)),
|
||||
((0, 5, FormLabel('Quality:')), (5, 7, self.quality_combo_box)),
|
||||
((0, 5, FormLabel('Microphone:')),
|
||||
(5, 7, self.audio_devices_combo_box)),
|
||||
((0, 5, FormLabel('Delay:')), (5, 7, delays_combo_box)),
|
||||
|
|
@ -489,8 +500,8 @@ class RecordingTranscriberWidget(QWidget):
|
|||
else:
|
||||
self.stop_recording()
|
||||
|
||||
def on_model_changed(self, model_name: str):
|
||||
self.selected_model_name = model_name
|
||||
def on_quality_changed(self, quality: Quality):
|
||||
self.selected_quality = quality
|
||||
|
||||
def on_language_changed(self, language: str):
|
||||
self.selected_language = language
|
||||
|
|
@ -504,8 +515,11 @@ class RecordingTranscriberWidget(QWidget):
|
|||
def start_recording(self):
|
||||
self.record_button.setDisabled(True)
|
||||
|
||||
model_name = get_model_name(
|
||||
self.selected_quality, self.selected_language)
|
||||
logging.debug(f'Loading model: {model_name}')
|
||||
model = _whisper.load_model(
|
||||
self.selected_model_name, on_download_model_chunk=self.on_download_model_chunk)
|
||||
model_name, on_download_model_chunk=self.on_download_model_chunk)
|
||||
|
||||
self.record_button.setDisabled(False)
|
||||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue