Replace models combo box with quality (#69)

This commit is contained in:
Chidi Williams 2022-10-15 14:47:29 +01:00 committed by GitHub
commit ec0aef39b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

78
gui.py
View file

@ -1,4 +1,5 @@
import enum
import logging
import os
import platform
from datetime import datetime
@ -102,25 +103,25 @@ class TasksComboBox(QComboBox):
self.taskChanged.emit(self.tasks[index])
class ModelsComboBox(QComboBox):
"""ModelsComboBox displays the list of available Whisper models for selection"""
modelNameChanged = pyqtSignal(str)
class Quality(enum.Enum):
LOW = 'low'
MEDIUM = 'medium'
HIGH = 'high'
def __init__(self, default_model_name: str, *args) -> None:
class QualityComboBox(QComboBox):
quality_changed = pyqtSignal(Quality)
def __init__(self, default_quality: Quality, *args) -> None:
super().__init__(*args)
self.models = whisper.available_models()
self.addItems(map(self.label, self.models))
self.qualities = [i for i in Quality]
self.addItems(
map(lambda quality: quality.value.title(), self.qualities))
self.currentIndexChanged.connect(self.on_index_changed)
self.setCurrentText(self.label(default_model_name))
self.setCurrentText(default_quality.value.title())
def on_index_changed(self, index: int):
self.modelNameChanged.emit(self.models[index])
def label(self, model_name: str):
name, lang = (model_name.split('.') + [None])[:2]
if lang:
return "%s (%s)" % (name.title(), lang.upper())
return name.title()
self.quality_changed.emit(self.qualities[index])
class DelaysComboBox(QComboBox):
@ -216,7 +217,6 @@ class DownloadModelProgressDialog(QProgressDialog):
class TranscriberProgressDialog(QProgressDialog):
total_size: int
short_file_path: str
start_time: datetime
@ -302,8 +302,16 @@ class TimerLabel(QLabel):
seconds_passed // 60, seconds_passed % 60))
def get_model_name(quality: Quality, language: str) -> str:
return {
Quality.LOW: ('tiny', 'tiny.en'),
Quality.MEDIUM: ('base', 'base.en'),
Quality.HIGH: ('small', 'small.en'),
}[quality][1 if language == 'en' else 0]
class FileTranscriberWidget(QWidget):
selected_model_name = 'tiny'
selected_quality = Quality.LOW
selected_language = 'en'
selected_task = Task.TRANSCRIBE
progress_dialog: Optional[DownloadModelProgressDialog] = None
@ -317,9 +325,9 @@ class FileTranscriberWidget(QWidget):
self.file_path = file_path
self.models_combo_box = ModelsComboBox(
default_model_name=self.selected_model_name)
self.models_combo_box.modelNameChanged.connect(self.on_model_changed)
self.quality_combo_box = QualityComboBox(
default_quality=self.selected_quality)
self.quality_combo_box.quality_changed.connect(self.on_quality_changed)
self.languages_combo_box = LanguagesComboBox(
default_language=self.selected_language)
@ -337,7 +345,7 @@ class FileTranscriberWidget(QWidget):
grid = (
((0, 5, FormLabel('Task:')), (5, 7, self.tasks_combo_box)),
((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)),
((0, 5, FormLabel('Model:')), (5, 7, self.models_combo_box)),
((0, 5, FormLabel('Quality:')), (5, 7, self.quality_combo_box)),
((9, 3, self.run_button),)
)
@ -350,8 +358,8 @@ class FileTranscriberWidget(QWidget):
self.transcribe_progress.connect(self.handle_transcribe_progress)
def on_model_changed(self, model_name: str):
self.selected_model_name = model_name
def on_quality_changed(self, quality: Quality):
self.selected_quality = quality
def on_language_changed(self, language: str):
self.selected_language = language
@ -369,8 +377,11 @@ class FileTranscriberWidget(QWidget):
return
self.run_button.setDisabled(True)
model_name = get_model_name(
self.selected_quality, self.selected_language)
logging.debug(f'Loading model: {model_name}')
model = _whisper.load_model(
self.selected_model_name, on_download_model_chunk=self.on_download_model_progress)
model_name, on_download_model_chunk=self.on_download_model_progress)
self.file_transcriber = FileTranscriber(
model=model, file_path=self.file_path,
@ -412,7 +423,7 @@ class FileTranscriberWidget(QWidget):
class RecordingTranscriberWidget(QWidget):
current_status = RecordButton.Status.STOPPED
selected_model_name = 'tiny'
selected_quality = Quality.LOW
selected_language = 'en'
selected_device_id: Optional[int]
selected_delay = 10
@ -424,9 +435,9 @@ class RecordingTranscriberWidget(QWidget):
layout = QGridLayout()
self.models_combo_box = ModelsComboBox(
default_model_name=self.selected_model_name)
self.models_combo_box.modelNameChanged.connect(self.on_model_changed)
self.quality_combo_box = QualityComboBox(
default_quality=self.selected_quality)
self.quality_combo_box.quality_changed.connect(self.on_quality_changed)
self.languages_combo_box = LanguagesComboBox(
default_language=self.selected_language)
@ -453,9 +464,9 @@ class RecordingTranscriberWidget(QWidget):
self.text_box = TextDisplayBox()
grid = (
((0, 5, FormLabel('Model:')), (5, 7, self.models_combo_box)),
((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)),
((0, 5, FormLabel('Task:')), (5, 7, self.tasks_combo_box)),
((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)),
((0, 5, FormLabel('Quality:')), (5, 7, self.quality_combo_box)),
((0, 5, FormLabel('Microphone:')),
(5, 7, self.audio_devices_combo_box)),
((0, 5, FormLabel('Delay:')), (5, 7, delays_combo_box)),
@ -489,8 +500,8 @@ class RecordingTranscriberWidget(QWidget):
else:
self.stop_recording()
def on_model_changed(self, model_name: str):
self.selected_model_name = model_name
def on_quality_changed(self, quality: Quality):
self.selected_quality = quality
def on_language_changed(self, language: str):
self.selected_language = language
@ -504,8 +515,11 @@ class RecordingTranscriberWidget(QWidget):
def start_recording(self):
self.record_button.setDisabled(True)
model_name = get_model_name(
self.selected_quality, self.selected_language)
logging.debug(f'Loading model: {model_name}')
model = _whisper.load_model(
self.selected_model_name, on_download_model_chunk=self.on_download_model_chunk)
model_name, on_download_model_chunk=self.on_download_model_chunk)
self.record_button.setDisabled(False)
# Clear text box placeholder because the first chunk takes a while to process