diff --git a/gui.py b/gui.py index e04c8e5a..a6dedb9e 100644 --- a/gui.py +++ b/gui.py @@ -1,4 +1,5 @@ import enum +import logging import os import platform from datetime import datetime @@ -102,25 +103,25 @@ class TasksComboBox(QComboBox): self.taskChanged.emit(self.tasks[index]) -class ModelsComboBox(QComboBox): - """ModelsComboBox displays the list of available Whisper models for selection""" - modelNameChanged = pyqtSignal(str) +class Quality(enum.Enum): + LOW = 'low' + MEDIUM = 'medium' + HIGH = 'high' - def __init__(self, default_model_name: str, *args) -> None: + +class QualityComboBox(QComboBox): + quality_changed = pyqtSignal(Quality) + + def __init__(self, default_quality: Quality, *args) -> None: super().__init__(*args) - self.models = whisper.available_models() - self.addItems(map(self.label, self.models)) + self.qualities = [i for i in Quality] + self.addItems( + map(lambda quality: quality.value.title(), self.qualities)) self.currentIndexChanged.connect(self.on_index_changed) - self.setCurrentText(self.label(default_model_name)) + self.setCurrentText(default_quality.value.title()) def on_index_changed(self, index: int): - self.modelNameChanged.emit(self.models[index]) - - def label(self, model_name: str): - name, lang = (model_name.split('.') + [None])[:2] - if lang: - return "%s (%s)" % (name.title(), lang.upper()) - return name.title() + self.quality_changed.emit(self.qualities[index]) class DelaysComboBox(QComboBox): @@ -216,7 +217,6 @@ class DownloadModelProgressDialog(QProgressDialog): class TranscriberProgressDialog(QProgressDialog): - total_size: int short_file_path: str start_time: datetime @@ -302,8 +302,16 @@ class TimerLabel(QLabel): seconds_passed // 60, seconds_passed % 60)) +def get_model_name(quality: Quality, language: str) -> str: + return { + Quality.LOW: ('tiny', 'tiny.en'), + Quality.MEDIUM: ('base', 'base.en'), + Quality.HIGH: ('small', 'small.en'), + }[quality][1 if language == 'en' else 0] + + class FileTranscriberWidget(QWidget): - selected_model_name = 'tiny' + selected_quality = Quality.LOW selected_language = 'en' selected_task = Task.TRANSCRIBE progress_dialog: Optional[DownloadModelProgressDialog] = None @@ -317,9 +325,9 @@ class FileTranscriberWidget(QWidget): self.file_path = file_path - self.models_combo_box = ModelsComboBox( - default_model_name=self.selected_model_name) - self.models_combo_box.modelNameChanged.connect(self.on_model_changed) + self.quality_combo_box = QualityComboBox( + default_quality=self.selected_quality) + self.quality_combo_box.quality_changed.connect(self.on_quality_changed) self.languages_combo_box = LanguagesComboBox( default_language=self.selected_language) @@ -337,7 +345,7 @@ class FileTranscriberWidget(QWidget): grid = ( ((0, 5, FormLabel('Task:')), (5, 7, self.tasks_combo_box)), ((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)), - ((0, 5, FormLabel('Model:')), (5, 7, self.models_combo_box)), + ((0, 5, FormLabel('Quality:')), (5, 7, self.quality_combo_box)), ((9, 3, self.run_button),) ) @@ -350,8 +358,8 @@ class FileTranscriberWidget(QWidget): self.transcribe_progress.connect(self.handle_transcribe_progress) - def on_model_changed(self, model_name: str): - self.selected_model_name = model_name + def on_quality_changed(self, quality: Quality): + self.selected_quality = quality def on_language_changed(self, language: str): self.selected_language = language @@ -369,8 +377,11 @@ class FileTranscriberWidget(QWidget): return self.run_button.setDisabled(True) + model_name = get_model_name( + self.selected_quality, self.selected_language) + logging.debug(f'Loading model: {model_name}') model = _whisper.load_model( - self.selected_model_name, on_download_model_chunk=self.on_download_model_progress) + model_name, on_download_model_chunk=self.on_download_model_progress) self.file_transcriber = FileTranscriber( model=model, file_path=self.file_path, @@ -412,7 +423,7 @@ class FileTranscriberWidget(QWidget): class RecordingTranscriberWidget(QWidget): current_status = RecordButton.Status.STOPPED - selected_model_name = 'tiny' + selected_quality = Quality.LOW selected_language = 'en' selected_device_id: Optional[int] selected_delay = 10 @@ -424,9 +435,9 @@ class RecordingTranscriberWidget(QWidget): layout = QGridLayout() - self.models_combo_box = ModelsComboBox( - default_model_name=self.selected_model_name) - self.models_combo_box.modelNameChanged.connect(self.on_model_changed) + self.quality_combo_box = QualityComboBox( + default_quality=self.selected_quality) + self.quality_combo_box.quality_changed.connect(self.on_quality_changed) self.languages_combo_box = LanguagesComboBox( default_language=self.selected_language) @@ -453,9 +464,9 @@ class RecordingTranscriberWidget(QWidget): self.text_box = TextDisplayBox() grid = ( - ((0, 5, FormLabel('Model:')), (5, 7, self.models_combo_box)), - ((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)), ((0, 5, FormLabel('Task:')), (5, 7, self.tasks_combo_box)), + ((0, 5, FormLabel('Language:')), (5, 7, self.languages_combo_box)), + ((0, 5, FormLabel('Quality:')), (5, 7, self.quality_combo_box)), ((0, 5, FormLabel('Microphone:')), (5, 7, self.audio_devices_combo_box)), ((0, 5, FormLabel('Delay:')), (5, 7, delays_combo_box)), @@ -489,8 +500,8 @@ class RecordingTranscriberWidget(QWidget): else: self.stop_recording() - def on_model_changed(self, model_name: str): - self.selected_model_name = model_name + def on_quality_changed(self, quality: Quality): + self.selected_quality = quality def on_language_changed(self, language: str): self.selected_language = language @@ -504,8 +515,11 @@ class RecordingTranscriberWidget(QWidget): def start_recording(self): self.record_button.setDisabled(True) + model_name = get_model_name( + self.selected_quality, self.selected_language) + logging.debug(f'Loading model: {model_name}') model = _whisper.load_model( - self.selected_model_name, on_download_model_chunk=self.on_download_model_chunk) + model_name, on_download_model_chunk=self.on_download_model_chunk) self.record_button.setDisabled(False) # Clear text box placeholder because the first chunk takes a while to process