From eb925683811a980e6a1e06190852b81e6a40271d Mon Sep 17 00:00:00 2001 From: Raivis Dejus Date: Sat, 15 Jun 2024 09:25:01 +0300 Subject: [PATCH] Add llm translations (#791) --- buzz/assets/translate_black.svg | 23 +++ buzz/buzz.py | 3 +- buzz/db/dao/transcription_segment_dao.py | 15 ++ buzz/db/entity/transcription_segment.py | 1 + buzz/db/helpers.py | 5 +- buzz/db/service/transcription_service.py | 4 + buzz/locale/lv_LV/LC_MESSAGES/buzz.po | 162 ++++++++++++------ buzz/schema.sql | 1 + buzz/settings/settings.py | 8 +- buzz/settings/shortcut.py | 1 + buzz/transcriber/file_transcriber.py | 13 +- buzz/transcriber/recording_transcriber.py | 5 +- buzz/transcriber/transcriber.py | 4 + buzz/transcriber/whisper_file_transcriber.py | 6 + buzz/translator.py | 87 ++++++++++ buzz/widgets/icon.py | 5 + buzz/widgets/main_window.py | 12 +- .../general_preferences_widget.py | 2 +- .../models/file_transcription_preferences.py | 23 ++- buzz/widgets/recording_transcriber_widget.py | 119 +++++++++++-- .../transcriber/advanced_settings_dialog.py | 68 +++++++- .../transcriber/initial_prompt_text_edit.py | 1 + .../transcription_options_group_box.py | 15 +- .../transcription_tasks_table_widget.py | 6 +- .../export_transcription_menu.py | 52 ++++-- .../transcription_segments_editor_widget.py | 77 ++++++++- .../transcription_view_mode_tool_button.py | 20 ++- .../transcription_viewer_widget.py | 123 ++++++++++++- tests/gui_test.py | 15 +- tests/translator_test.py | 116 +++++++++++++ .../widgets/export_transcription_menu_test.py | 4 +- .../folder_watch_preferences_widget_test.py | 3 + .../general_preferences_widget_test.py | 4 +- .../recording_transcriber_widget_test.py | 14 +- tests/widgets/shortcuts_editor_widget_test.py | 1 + .../transcription_task_folder_watcher_test.py | 6 + tests/widgets/transcription_viewer_test.py | 33 +++- 37 files changed, 922 insertions(+), 135 deletions(-) create mode 100644 buzz/assets/translate_black.svg create mode 100644 buzz/translator.py create mode 100644 tests/translator_test.py diff --git a/buzz/assets/translate_black.svg b/buzz/assets/translate_black.svg new file mode 100644 index 00000000..627853ca --- /dev/null +++ b/buzz/assets/translate_black.svg @@ -0,0 +1,23 @@ + + + + + \ No newline at end of file diff --git a/buzz/buzz.py b/buzz/buzz.py index 8cddf355..c3e388a7 100644 --- a/buzz/buzz.py +++ b/buzz/buzz.py @@ -6,7 +6,7 @@ import platform import sys from typing import TextIO -from platformdirs import user_log_dir, user_cache_dir +from platformdirs import user_log_dir, user_cache_dir, user_data_dir from buzz.assets import APP_BASE_DIR @@ -60,6 +60,7 @@ def main(): logging.debug("app_dir: %s", APP_BASE_DIR) logging.debug("log_dir: %s", log_dir) logging.debug("cache_dir: %s", user_cache_dir("Buzz")) + logging.debug("data_dir: %s", user_data_dir("Buzz")) app = Application(sys.argv) parse_command_line(app) diff --git a/buzz/db/dao/transcription_segment_dao.py b/buzz/db/dao/transcription_segment_dao.py index bfff19fa..1b28c6c7 100644 --- a/buzz/db/dao/transcription_segment_dao.py +++ b/buzz/db/dao/transcription_segment_dao.py @@ -24,3 +24,18 @@ class TranscriptionSegmentDAO(DAO[TranscriptionSegment]): ) query.bindValue(":transcription_id", str(transcription_id)) return self._execute_all(query) + + def update_segment_translation(self, segment_id: int, translation: str): + query = self._create_query() + query.prepare( + """ + UPDATE transcription_segment + SET translation = :translation + WHERE id = :id + """ + ) + + query.bindValue(":id", segment_id) + query.bindValue(":translation", translation) + if not query.exec(): + raise Exception(query.lastError().text()) diff --git a/buzz/db/entity/transcription_segment.py b/buzz/db/entity/transcription_segment.py index 0e2f5d1f..4af38a5e 100644 --- a/buzz/db/entity/transcription_segment.py +++ b/buzz/db/entity/transcription_segment.py @@ -8,5 +8,6 @@ class TranscriptionSegment(Entity): start_time: int end_time: int text: str + translation: str transcription_id: str id: int = -1 diff --git a/buzz/db/helpers.py b/buzz/db/helpers.py index 767fae9d..d0985865 100644 --- a/buzz/db/helpers.py +++ b/buzz/db/helpers.py @@ -53,13 +53,14 @@ def copy_transcriptions_from_json_to_sqlite(conn: Connection): for segment in task.segments: cursor.execute( """ - INSERT INTO transcription_segment (end_time, start_time, text, transcription_id) - VALUES (?, ?, ?, ?); + INSERT INTO transcription_segment (end_time, start_time, text, translation, transcription_id) + VALUES (?, ?, ?, ?, ?); """, ( segment.end, segment.start, segment.text, + segment.translation, transcription_id, ), ) diff --git a/buzz/db/service/transcription_service.py b/buzz/db/service/transcription_service.py index 00972c3d..d6f75274 100644 --- a/buzz/db/service/transcription_service.py +++ b/buzz/db/service/transcription_service.py @@ -39,9 +39,13 @@ class TranscriptionService: start_time=segment.start, end_time=segment.end, text=segment.text, + translation='', transcription_id=str(id), ) ) def get_transcription_segments(self, transcription_id: UUID): return self.transcription_segment_dao.get_segments(transcription_id) + + def update_segment_translation(self, segment_id: int, translation: str): + return self.transcription_segment_dao.update_segment_translation(segment_id, translation) diff --git a/buzz/locale/lv_LV/LC_MESSAGES/buzz.po b/buzz/locale/lv_LV/LC_MESSAGES/buzz.po index 5f5f66e9..6702ae06 100644 --- a/buzz/locale/lv_LV/LC_MESSAGES/buzz.po +++ b/buzz/locale/lv_LV/LC_MESSAGES/buzz.po @@ -8,8 +8,8 @@ msgid "" msgstr "" "Project-Id-Version: \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-06-10 21:58+0300\n" -"PO-Revision-Date: 2024-06-10 21:59+0300\n" +"POT-Creation-Date: 2024-06-14 18:59+0300\n" +"PO-Revision-Date: 2024-06-14 19:01+0300\n" "Last-Translator: \n" "Language-Team: \n" "Language: lv_LV\n" @@ -39,51 +39,63 @@ msgid "View Transcript Text" msgstr "Aplūkot atpazīto tekstu" #: buzz/settings/shortcut.py:23 +msgid "View Transcript Translation" +msgstr "Aplūkot tulkojumu" + +#: buzz/settings/shortcut.py:24 msgid "View Transcript Timestamps" msgstr "Aplūkot atpazīšanas laikus" -#: buzz/settings/shortcut.py:25 buzz/widgets/main_window_toolbar.py:60 +#: buzz/settings/shortcut.py:26 buzz/widgets/main_window_toolbar.py:60 #: buzz/widgets/main_window.py:222 msgid "Clear History" msgstr "Notīrīt vēsturi" -#: buzz/settings/shortcut.py:26 buzz/widgets/main_window_toolbar.py:52 +#: buzz/settings/shortcut.py:27 buzz/widgets/main_window_toolbar.py:52 msgid "Cancel Transcription" msgstr "Atcelt atpazīšanu" -#: buzz/widgets/transcription_tasks_table_widget.py:64 +#: buzz/widgets/transcription_tasks_table_widget.py:62 +msgid "In Progress" +msgstr "Apstrādā" + +#: buzz/widgets/transcription_tasks_table_widget.py:65 msgid "Completed" msgstr "Pabeigts" -#: buzz/widgets/transcription_tasks_table_widget.py:73 +#: buzz/widgets/transcription_tasks_table_widget.py:72 +msgid "Failed" +msgstr "Neizdevās" + +#: buzz/widgets/transcription_tasks_table_widget.py:75 msgid "Canceled" msgstr "Atcelts" -#: buzz/widgets/transcription_tasks_table_widget.py:75 +#: buzz/widgets/transcription_tasks_table_widget.py:77 msgid "Queued" msgstr "Ierindots" -#: buzz/widgets/transcription_tasks_table_widget.py:83 +#: buzz/widgets/transcription_tasks_table_widget.py:85 msgid "File Name / URL" msgstr "Fails / URL" -#: buzz/widgets/transcription_tasks_table_widget.py:95 +#: buzz/widgets/transcription_tasks_table_widget.py:97 msgid "Model" msgstr "Modelis" -#: buzz/widgets/transcription_tasks_table_widget.py:104 +#: buzz/widgets/transcription_tasks_table_widget.py:106 msgid "Task" msgstr "Uzdevums" -#: buzz/widgets/transcription_tasks_table_widget.py:113 +#: buzz/widgets/transcription_tasks_table_widget.py:115 msgid "Status" msgstr "Statuss" -#: buzz/widgets/transcription_tasks_table_widget.py:121 +#: buzz/widgets/transcription_tasks_table_widget.py:123 msgid "Date Added" msgstr "Pievienots" -#: buzz/widgets/transcription_tasks_table_widget.py:132 +#: buzz/widgets/transcription_tasks_table_widget.py:134 msgid "Date Completed" msgstr "Pabeigts" @@ -124,23 +136,27 @@ msgstr "Adrese nav derīga" msgid "The URL you entered is invalid." msgstr "Jūsu ievadītā URL adrese nav derīga." -#: buzz/widgets/recording_transcriber_widget.py:63 +#: buzz/widgets/recording_transcriber_widget.py:66 msgid "Live Recording" msgstr "Dzīvā ierakstīšana" -#: buzz/widgets/recording_transcriber_widget.py:112 +#: buzz/widgets/recording_transcriber_widget.py:125 msgid "Click Record to begin..." msgstr "Klikšķiniet Ierakstīt, lai sāktu..." -#: buzz/widgets/recording_transcriber_widget.py:124 +#: buzz/widgets/recording_transcriber_widget.py:128 +msgid "Waiting for AI translation..." +msgstr "Gaida MI tulkojumu..." + +#: buzz/widgets/recording_transcriber_widget.py:140 msgid "Microphone:" msgstr "Mikrofons:" -#: buzz/widgets/recording_transcriber_widget.py:319 +#: buzz/widgets/recording_transcriber_widget.py:391 msgid "An error occurred while starting a new recording:" msgstr "Sākot jaunu ierakstu notikusi kļūda:" -#: buzz/widgets/recording_transcriber_widget.py:323 +#: buzz/widgets/recording_transcriber_widget.py:395 msgid "" "Please check your audio devices or check the application logs for more " "information." @@ -157,8 +173,8 @@ msgid "" "Detected missing permissions, please check that snap permissions have been " "granted" msgstr "" -"Ne visi nepieciešamie moduļi darbojas korekti, iespējams nav piešķirtas " -"snap atļaujas" +"Ne visi nepieciešamie moduļi darbojas korekti, iespējams nav piešķirtas snap " +"atļaujas" #: buzz/widgets/snap_notice.py:16 msgid "" @@ -201,67 +217,113 @@ msgstr "" msgid "Select audio file" msgstr "Izvēlieties audio failu" -#: buzz/widgets/main_window.py:278 +#: buzz/widgets/main_window.py:280 #: buzz/widgets/preferences_dialog/models_preferences_widget.py:191 msgid "Error" msgstr "Kļūda" -#: buzz/widgets/main_window.py:278 +#: buzz/widgets/main_window.py:280 msgid "Unable to save OpenAI API key to keyring" msgstr "Neizdevās saglabāt OpenAI API atslēgu atslēgu saišķī" #: buzz/widgets/transcription_viewer/export_transcription_menu.py:42 -msgid "Save File" -msgstr "Saglabāt failu" - -#: buzz/widgets/transcription_viewer/export_transcription_menu.py:44 -msgid "Text files" -msgstr "Teksta faili" - -#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:19 -msgid "View" -msgstr "Skats" - -#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:27 -#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:68 +#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:34 +#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:95 msgid "Text" msgstr "Teksts" -#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:33 +#: buzz/widgets/transcription_viewer/export_transcription_menu.py:43 +#: buzz/widgets/transcription_viewer/export_transcription_menu.py:65 +#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:40 +#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:96 +msgid "Translation" +msgstr "Tulkojums" + +#: buzz/widgets/transcription_viewer/export_transcription_menu.py:79 +msgid "Save File" +msgstr "Saglabāt failu" + +#: buzz/widgets/transcription_viewer/export_transcription_menu.py:81 +msgid "Text files" +msgstr "Teksta faili" + +#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:26 +msgid "View" +msgstr "Skats" + +#: buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py:46 msgid "Timestamps" msgstr "Laiks" -#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:66 +#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:93 msgid "Start" msgstr "Sākums" -#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:67 +#: buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py:94 msgid "End" msgstr "Beigas" -#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:92 +#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:146 msgid "Export" msgstr "Eksportēt" +#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:160 +msgid "Translate" +msgstr "Tulkot" + +#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:250 +msgid "API Key Required" +msgstr "API atslēgas kļūda" + +#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:251 +msgid "Please enter OpenAI API Key in preferences" +msgstr "Lūdzu ievadiet OpenAI API atslēgu iestatījumos" + #: buzz/widgets/record_button.py:21 msgid "Stop" msgstr "Apturēt" -#: buzz/widgets/transcriber/advanced_settings_dialog.py:28 +#: buzz/widgets/transcriber/advanced_settings_dialog.py:33 msgid "Advanced Settings" msgstr "Papildu iestatījumi" -#: buzz/widgets/transcriber/advanced_settings_dialog.py:42 +#: buzz/widgets/transcriber/advanced_settings_dialog.py:37 +msgid "Speech recognition settings" +msgstr "Runas atpazīšanas iestatījumi" + +#: buzz/widgets/transcriber/advanced_settings_dialog.py:46 msgid "Comma-separated, e.g. \"0.0, 0.2, 0.4, 0.6, 0.8, 1.0\"" msgstr "Atdalīti ar komatu, piemēram, \"0.0, 0.2, 0.4, 0.6, 0.8, 1.0\"" -#: buzz/widgets/transcriber/advanced_settings_dialog.py:60 +#: buzz/widgets/transcriber/advanced_settings_dialog.py:55 msgid "Temperature:" msgstr "Temperatūra:" -#: buzz/widgets/transcriber/advanced_settings_dialog.py:61 +#: buzz/widgets/transcriber/advanced_settings_dialog.py:66 msgid "Initial Prompt:" -msgstr "Sākotnējais vaicājums:" +msgstr "" +"Sākotnējais\n" +"vaicājums:" + +#: buzz/widgets/transcriber/advanced_settings_dialog.py:68 +msgid "Translation settings" +msgstr "Tulkojuma iestatījumi" + +#: buzz/widgets/transcriber/advanced_settings_dialog.py:72 +msgid "Enable AI translation" +msgstr "Tulkot ar MI" + +#: buzz/widgets/transcriber/advanced_settings_dialog.py:84 +msgid "AI model:" +msgstr "AI modelis:" + +#: buzz/widgets/transcriber/advanced_settings_dialog.py:88 +msgid "Enter instructions for AI on how to translate..." +msgstr "Ievadiet tulkošanas norādes mākslīgajam intelektam..." + +#: buzz/widgets/transcriber/advanced_settings_dialog.py:92 +msgid "Instructions for AI:" +msgstr "Norādes MI:" #: buzz/widgets/transcriber/file_transcription_form_widget.py:42 msgid "Word-level timings" @@ -279,20 +341,20 @@ msgstr "Papildu iestatījumi..." msgid "Enter prompt..." msgstr "Ievadiet vaicājumu..." -#: buzz/widgets/transcriber/transcription_options_group_box.py:79 +#: buzz/widgets/transcriber/transcription_options_group_box.py:86 msgid "Model:" msgstr "Modelis:" -#: buzz/widgets/transcriber/transcription_options_group_box.py:83 +#: buzz/widgets/transcriber/transcription_options_group_box.py:90 msgid "Task:" msgstr "Uzdevums:" -#: buzz/widgets/transcriber/transcription_options_group_box.py:84 +#: buzz/widgets/transcriber/transcription_options_group_box.py:91 msgid "Language:" msgstr "Valoda:" #: buzz/widgets/transcriber/languages_combo_box.py:25 -#: buzz/transcriber/transcriber.py:149 +#: buzz/transcriber/transcriber.py:153 msgid "Detect Language" msgstr "Noteikt valodu" @@ -358,10 +420,10 @@ msgstr "OpenAI API atslēgas pārbaude" #: buzz/widgets/preferences_dialog/general_preferences_widget.py:129 msgid "" "Your API key is valid. Buzz will use this key to perform Whisper API " -"transcriptions." +"transcriptions and AI translations with ChatGPT." msgstr "" "Jūsu API atslēga ir derīga. Buzz izmantos to runas atpazīšanai ar Whisper " -"API." +"API un tulkošanai ar ChatGPT." #: buzz/widgets/preferences_dialog/general_preferences_widget.py:156 msgid "Select Export Folder" diff --git a/buzz/schema.sql b/buzz/schema.sql index ee18b4f0..d6046b54 100644 --- a/buzz/schema.sql +++ b/buzz/schema.sql @@ -23,6 +23,7 @@ CREATE TABLE transcription_segment ( end_time INT DEFAULT 0, start_time INT DEFAULT 0, text TEXT NOT NULL, + translation TEXT DEFAULT '', transcription_id TEXT, FOREIGN KEY (transcription_id) REFERENCES transcription(id) ON DELETE CASCADE ); diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 43eeab25..62fb6e26 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -11,7 +11,7 @@ class Settings: def __init__(self, application=""): self.settings = QSettings(APP_NAME, application) self.settings.sync() - logging.debug(f"settings filename: {self.settings.fileName()}") + logging.debug(f"Settings filename: {self.settings.fileName()}") class Key(enum.Enum): RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task" @@ -19,6 +19,9 @@ class Settings: RECORDING_TRANSCRIBER_LANGUAGE = "recording-transcriber/language" RECORDING_TRANSCRIBER_TEMPERATURE = "recording-transcriber/temperature" RECORDING_TRANSCRIBER_INITIAL_PROMPT = "recording-transcriber/initial-prompt" + RECORDING_TRANSCRIBER_ENABLE_LLM_TRANSLATION = "recording-transcriber/enable-llm-translation" + RECORDING_TRANSCRIBER_LLM_MODEL = "recording-transcriber/llm-model" + RECORDING_TRANSCRIBER_LLM_PROMPT = "recording-transcriber/llm-prompt" RECORDING_TRANSCRIBER_EXPORT_ENABLED = "recording-transcriber/export-enabled" RECORDING_TRANSCRIBER_EXPORT_FOLDER = "recording-transcriber/export-folder" @@ -27,6 +30,9 @@ class Settings: FILE_TRANSCRIBER_LANGUAGE = "file-transcriber/language" FILE_TRANSCRIBER_TEMPERATURE = "file-transcriber/temperature" FILE_TRANSCRIBER_INITIAL_PROMPT = "file-transcriber/initial-prompt" + FILE_TRANSCRIBER_ENABLE_LLM_TRANSLATION = "file-transcriber/enable-llm-translation" + FILE_TRANSCRIBER_LLM_MODEL = "file-transcriber/llm-model" + FILE_TRANSCRIBER_LLM_PROMPT = "file-transcriber/llm-prompt" FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = "file-transcriber/word-level-timings" FILE_TRANSCRIBER_EXPORT_FORMATS = "file-transcriber/export-formats" diff --git a/buzz/settings/shortcut.py b/buzz/settings/shortcut.py index 023448bd..f61d8d42 100644 --- a/buzz/settings/shortcut.py +++ b/buzz/settings/shortcut.py @@ -20,6 +20,7 @@ class Shortcut(str, enum.Enum): OPEN_PREFERENCES_WINDOW = ("Ctrl+,", _("Open Preferences Window")) VIEW_TRANSCRIPT_TEXT = ("Ctrl+E", _("View Transcript Text")) + VIEW_TRANSCRIPT_TRANSLATION = ("Ctrl+L", _("View Transcript Translation")) VIEW_TRANSCRIPT_TIMESTAMPS = ("Ctrl+T", _("View Transcript Timestamps")) CLEAR_HISTORY = ("Ctrl+S", _("Clear History")) diff --git a/buzz/transcriber/file_transcriber.py b/buzz/transcriber/file_transcriber.py index 07151fb4..ed7f9505 100644 --- a/buzz/transcriber/file_transcriber.py +++ b/buzz/transcriber/file_transcriber.py @@ -103,7 +103,12 @@ class FileTranscriber(QObject): # TODO: Move to transcription service -def write_output(path: str, segments: List[Segment], output_format: OutputFormat): +def write_output( + path: str, + segments: List[Segment], + output_format: OutputFormat, + segment_key: str = 'text' +): logging.debug( "Writing transcription output, path = %s, output format = %s, number of segments = %s", path, @@ -114,7 +119,7 @@ def write_output(path: str, segments: List[Segment], output_format: OutputFormat with open(path, "w", encoding="utf-8") as file: if output_format == OutputFormat.TXT: for i, segment in enumerate(segments): - file.write(segment.text) + file.write(getattr(segment, segment_key)) file.write("\n") elif output_format == OutputFormat.VTT: @@ -123,7 +128,7 @@ def write_output(path: str, segments: List[Segment], output_format: OutputFormat file.write( f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n" ) - file.write(f"{segment.text}\n\n") + file.write(f"{getattr(segment, segment_key)}\n\n") elif output_format == OutputFormat.SRT: for i, segment in enumerate(segments): @@ -131,7 +136,7 @@ def write_output(path: str, segments: List[Segment], output_format: OutputFormat file.write( f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n' ) - file.write(f"{segment.text}\n\n") + file.write(f"{getattr(segment, segment_key)}\n\n") logging.debug("Written transcription output") diff --git a/buzz/transcriber/recording_transcriber.py b/buzz/transcriber/recording_transcriber.py index def9a331..a827346f 100644 --- a/buzz/transcriber/recording_transcriber.py +++ b/buzz/transcriber/recording_transcriber.py @@ -49,6 +49,7 @@ class RecordingTranscriber(QObject): def start(self): model_path = self.model_path + keep_samples = int(0.15 * self.sample_rate) if self.transcription_options.model.model_type == ModelType.WHISPER: model = whisper.load_model(model_path) @@ -82,7 +83,7 @@ class RecordingTranscriber(QObject): self.mutex.acquire() if self.queue.size >= self.n_batch_samples: samples = self.queue[: self.n_batch_samples] - self.queue = self.queue[self.n_batch_samples :] + self.queue = self.queue[self.n_batch_samples - keep_samples:] self.mutex.release() logging.debug( @@ -124,7 +125,7 @@ class RecordingTranscriber(QObject): whisper_segments, info = model.transcribe( audio=samples, language=self.transcription_options.language - if self.transcription_options.language is not "" + if self.transcription_options.language != "" else None, task=self.transcription_options.task.value, temperature=self.transcription_options.temperature, diff --git a/buzz/transcriber/transcriber.py b/buzz/transcriber/transcriber.py index 7077f376..ba076387 100644 --- a/buzz/transcriber/transcriber.py +++ b/buzz/transcriber/transcriber.py @@ -25,6 +25,7 @@ class Segment: start: int # start time in ms end: int # end time in ms text: str + translation: str = "" LANGUAGES = { @@ -142,6 +143,9 @@ class TranscriptionOptions: openai_access_token: str = field( default="", metadata=config(exclude=Exclude.ALWAYS) ) + enable_llm_translation: bool = False + llm_prompt: str = "" + llm_model: str = "" def humanize_language(language: str) -> str: diff --git a/buzz/transcriber/whisper_file_transcriber.py b/buzz/transcriber/whisper_file_transcriber.py index e1f7f620..bd024fcf 100644 --- a/buzz/transcriber/whisper_file_transcriber.py +++ b/buzz/transcriber/whisper_file_transcriber.py @@ -121,6 +121,7 @@ class WhisperFileTranscriber(FileTranscriber): start=int(segment.get("start") * 1000), end=int(segment.get("end") * 1000), text=segment.get("text"), + translation="" ) for segment in result.get("segments") ] @@ -149,6 +150,7 @@ class WhisperFileTranscriber(FileTranscriber): start=int(word.start * 1000), end=int(word.end * 1000), text=word.word, + translation="" ) ) else: @@ -157,6 +159,7 @@ class WhisperFileTranscriber(FileTranscriber): start=int(segment.start * 1000), end=int(segment.end * 1000), text=segment.text, + translation="" ) ) @@ -181,6 +184,7 @@ class WhisperFileTranscriber(FileTranscriber): start=int(word.start * 1000), end=int(word.end * 1000), text=word.word.strip(), + translation="" ) for segment in result.segments for word in segment.words @@ -200,6 +204,7 @@ class WhisperFileTranscriber(FileTranscriber): start=int(segment.get("start") * 1000), end=int(segment.get("end") * 1000), text=segment.get("text"), + translation="" ) for segment in segments ] @@ -226,6 +231,7 @@ class WhisperFileTranscriber(FileTranscriber): start=segment.get("start"), end=segment.get("end"), text=segment.get("text"), + translation="" ) for segment in segments_dict ] diff --git a/buzz/translator.py b/buzz/translator.py new file mode 100644 index 00000000..5be35579 --- /dev/null +++ b/buzz/translator.py @@ -0,0 +1,87 @@ +import logging +import queue + +from typing import Optional +from openai import OpenAI +from PyQt6.QtCore import QObject, pyqtSignal + +from buzz.settings.settings import Settings +from buzz.store.keyring_store import get_password, Key +from buzz.transcriber.transcriber import TranscriptionOptions +from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog + + +class Translator(QObject): + translation = pyqtSignal(str, int) + finished = pyqtSignal() + is_running = False + + def __init__( + self, + transcription_options: TranscriptionOptions, + advanced_settings_dialog: AdvancedSettingsDialog, + parent: Optional[QObject] = None, + ) -> None: + super().__init__(parent) + + logging.debug(f"Translator init: {transcription_options}") + + self.transcription_options = transcription_options + self.advanced_settings_dialog = advanced_settings_dialog + self.advanced_settings_dialog.transcription_options_changed.connect( + self.on_transcription_options_changed + ) + + self.queue = queue.Queue() + + settings = Settings() + custom_openai_base_url = settings.value( + key=Settings.Key.CUSTOM_OPENAI_BASE_URL, default_value="" + ) + openai_api_key = get_password(Key.OPENAI_API_KEY) + self.openai_client = OpenAI( + api_key=openai_api_key, + base_url=custom_openai_base_url if custom_openai_base_url else None + ) + + def start(self): + logging.debug("Starting translation queue") + + self.is_running = True + + while self.is_running: + try: + transcript, transcript_id = self.queue.get(timeout=1) + except queue.Empty: + continue + + completion = self.openai_client.chat.completions.create( + model=self.transcription_options.llm_model, + messages=[ + {"role": "system", "content": self.transcription_options.llm_prompt}, + {"role": "user", "content": transcript} + ] + ) + + logging.debug(f"Received translation response: {completion}") + + if completion.choices and completion.choices[0].message: + next_translation = completion.choices[0].message.content + else: + logging.error(f"Translation error! Server response: {completion}") + next_translation = "Translation error, see logs!" + + self.translation.emit(next_translation, transcript_id) + + self.finished.emit() + + def on_transcription_options_changed( + self, transcription_options: TranscriptionOptions + ): + self.transcription_options = transcription_options + + def enqueue(self, transcript: str, transcript_id: Optional[int] = None): + self.queue.put((transcript, transcript_id)) + + def stop(self): + self.is_running = False diff --git a/buzz/widgets/icon.py b/buzz/widgets/icon.py index e86fa7fc..86c5667d 100644 --- a/buzz/widgets/icon.py +++ b/buzz/widgets/icon.py @@ -74,6 +74,11 @@ class FileDownloadIcon(Icon): super().__init__(get_path("assets/file_download_black_24dp.svg"), parent) +class TranslateIcon(Icon): + def __init__(self, parent: QWidget): + super().__init__(get_path("assets/translate_black.svg"), parent) + + class VisibilityIcon(Icon): def __init__(self, parent: QWidget): super().__init__( diff --git a/buzz/widgets/main_window.py b/buzz/widgets/main_window.py index 95707886..960153bd 100644 --- a/buzz/widgets/main_window.py +++ b/buzz/widgets/main_window.py @@ -141,6 +141,8 @@ class MainWindow(QMainWindow): self.folder_watcher.task_found.connect(self.add_task) self.folder_watcher.find_tasks() + self.transcription_viewer_widget = None + if os.environ.get('SNAP_NAME', '') == 'buzz': logging.debug("Running in a snap environment") self.check_linux_permissions() @@ -267,6 +269,8 @@ class MainWindow(QMainWindow): self.on_openai_access_token_changed ) file_transcriber_window.show() + file_transcriber_window.raise_() + file_transcriber_window.activateWindow() @staticmethod def on_openai_access_token_changed(access_token: str): @@ -347,14 +351,14 @@ class MainWindow(QMainWindow): self.open_transcription_viewer(transcription) def open_transcription_viewer(self, transcription: Transcription): - transcription_viewer_widget = TranscriptionViewerWidget( + self.transcription_viewer_widget = TranscriptionViewerWidget( transcription=transcription, transcription_service=self.transcription_service, shortcuts=self.shortcuts, parent=self, flags=Qt.WindowType.Window, ) - transcription_viewer_widget.show() + self.transcription_viewer_widget.show() def add_task(self, task: FileTranscriptionTask): self.transcription_service.create_transcription(task) @@ -396,6 +400,10 @@ class MainWindow(QMainWindow): self.transcriber_worker.stop() self.transcriber_thread.quit() self.transcriber_thread.wait() + + if self.transcription_viewer_widget is not None: + self.transcription_viewer_widget.close() + super().closeEvent(event) def save_geometry(self): diff --git a/buzz/widgets/preferences_dialog/general_preferences_widget.py b/buzz/widgets/preferences_dialog/general_preferences_widget.py index 2021fbb8..7296a568 100644 --- a/buzz/widgets/preferences_dialog/general_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/general_preferences_widget.py @@ -126,7 +126,7 @@ class GeneralPreferencesWidget(QWidget): QMessageBox.information( self, _("OpenAI API Key Test"), - _("Your API key is valid. Buzz will use this key to perform Whisper API transcriptions."), + _("Your API key is valid. Buzz will use this key to perform Whisper API transcriptions and AI translations with ChatGPT."), ) def on_test_openai_api_key_failure(self, error: str): diff --git a/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py b/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py index e706f240..228bb8b7 100644 --- a/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py +++ b/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py @@ -21,6 +21,9 @@ class FileTranscriptionPreferences: word_level_timings: bool temperature: Tuple[float, ...] initial_prompt: str + enable_llm_translation: bool + llm_prompt: str + llm_model: str output_formats: Set["OutputFormat"] def save(self, settings: QSettings) -> None: @@ -30,6 +33,9 @@ class FileTranscriptionPreferences: settings.setValue("word_level_timings", self.word_level_timings) settings.setValue("temperature", self.temperature) settings.setValue("initial_prompt", self.initial_prompt) + settings.setValue("enable_llm_translation", self.enable_llm_translation) + settings.setValue("llm_model", self.llm_model) + settings.setValue("llm_prompt", self.llm_prompt) settings.setValue( "output_formats", [output_format.value for output_format in self.output_formats], @@ -44,10 +50,16 @@ class FileTranscriptionPreferences: ) word_level_timings_value = settings.value("word_level_timings", False) - word_level_timings = False if word_level_timings_value == "false" else bool(word_level_timings_value) + word_level_timings = False if word_level_timings_value == "false" \ + else bool(word_level_timings_value) temperature = settings.value("temperature", DEFAULT_WHISPER_TEMPERATURE) initial_prompt = settings.value("initial_prompt", "") + enable_llm_translation_value = settings.value("enable_llm_translation", False) + enable_llm_translation = False if enable_llm_translation_value == "false" \ + else bool(enable_llm_translation_value) + llm_model = settings.value("llm_model", "") + llm_prompt = settings.value("llm_prompt", "") output_formats = settings.value("output_formats", []) or [] return FileTranscriptionPreferences( language=language, @@ -58,6 +70,9 @@ class FileTranscriptionPreferences: word_level_timings=word_level_timings, temperature=temperature, initial_prompt=initial_prompt, + enable_llm_translation=enable_llm_translation, + llm_model=llm_model, + llm_prompt=llm_prompt, output_formats=set( [OutputFormat(output_format) for output_format in output_formats] ), @@ -74,6 +89,9 @@ class FileTranscriptionPreferences: language=transcription_options.language, temperature=transcription_options.temperature, initial_prompt=transcription_options.initial_prompt, + enable_llm_translation=transcription_options.enable_llm_translation, + llm_model=transcription_options.llm_model, + llm_prompt=transcription_options.llm_prompt, word_level_timings=transcription_options.word_level_timings, model=transcription_options.model, output_formats=file_transcription_options.output_formats, @@ -91,6 +109,9 @@ class FileTranscriptionPreferences: language=self.language, temperature=self.temperature, initial_prompt=self.initial_prompt, + enable_llm_translation=self.enable_llm_translation, + llm_model=self.llm_model, + llm_prompt=self.llm_prompt, word_level_timings=self.word_level_timings, model=self.model, openai_access_token=openai_access_token, diff --git a/buzz/widgets/recording_transcriber_widget.py b/buzz/widgets/recording_transcriber_widget.py index 1437c66f..3eada2fc 100644 --- a/buzz/widgets/recording_transcriber_widget.py +++ b/buzz/widgets/recording_transcriber_widget.py @@ -24,6 +24,7 @@ from buzz.transcriber.transcriber import ( DEFAULT_WHISPER_TEMPERATURE, Task, ) +from buzz.translator import Translator from buzz.widgets.audio_devices_combo_box import AudioDevicesComboBox from buzz.widgets.audio_meter_widget import AudioMeterWidget from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog @@ -59,6 +60,8 @@ class RecordingTranscriberWidget(QWidget): layout = QVBoxLayout(self) + self.translation_thread = None + self.translator = None self.current_status = self.RecordingStatus.STOPPED self.setWindowTitle(_("Live Recording")) @@ -99,6 +102,16 @@ class RecordingTranscriberWidget(QWidget): default_value=DEFAULT_WHISPER_TEMPERATURE, ), word_level_timings=False, + enable_llm_translation=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_ENABLE_LLM_TRANSLATION, + default_value=False, + ), + llm_model=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_LLM_MODEL, default_value="" + ), + llm_prompt=self.settings.value( + key=Settings.Key.RECORDING_TRANSCRIBER_LLM_PROMPT, default_value="" + ), ) self.audio_devices_combo_box = AudioDevicesComboBox(self) @@ -108,15 +121,18 @@ class RecordingTranscriberWidget(QWidget): self.record_button = RecordButton(self) self.record_button.clicked.connect(self.on_record_button_clicked) - self.text_box = TextDisplayBox(self) - self.text_box.setPlaceholderText(_("Click Record to begin...")) + self.transcription_text_box = TextDisplayBox(self) + self.transcription_text_box.setPlaceholderText(_("Click Record to begin...")) - transcription_options_group_box = TranscriptionOptionsGroupBox( + self.translation_text_box = TextDisplayBox(self) + self.translation_text_box.setPlaceholderText(_("Waiting for AI translation...")) + + self.transcription_options_group_box = TranscriptionOptionsGroupBox( default_transcription_options=self.transcription_options, model_types=model_types, parent=self, ) - transcription_options_group_box.transcription_options_changed.connect( + self.transcription_options_group_box.transcription_options_changed.connect( self.on_transcription_options_changed ) @@ -129,17 +145,22 @@ class RecordingTranscriberWidget(QWidget): record_button_layout.addWidget(self.audio_meter_widget) record_button_layout.addWidget(self.record_button) - layout.addWidget(transcription_options_group_box) + layout.addWidget(self.transcription_options_group_box) layout.addLayout(recording_options_layout) layout.addLayout(record_button_layout) - layout.addWidget(self.text_box) + layout.addWidget(self.transcription_text_box) + layout.addWidget(self.translation_text_box) + + if not self.transcription_options.enable_llm_translation: + self.translation_text_box.hide() self.setLayout(layout) - self.setFixedSize(self.sizeHint()) + self.resize(450, 500) self.reset_recording_amplitude_listener() - self.export_file = None + self.transcript_export_file = None + self.translation_export_file = None self.export_enabled = self.settings.value( key=Settings.Key.RECORDING_TRANSCRIBER_EXPORT_ENABLED, default_value=False, @@ -168,13 +189,19 @@ class RecordingTranscriberWidget(QWidget): if not os.path.isdir(export_folder): self.export_enabled = False - self.export_file = os.path.join(export_folder, export_file_name) + self.transcript_export_file = os.path.join(export_folder, export_file_name) + self.translation_export_file = self.transcript_export_file.replace(".txt", ".translated.txt") def on_transcription_options_changed( self, transcription_options: TranscriptionOptions ): self.transcription_options = transcription_options + if self.transcription_options.enable_llm_translation: + self.translation_text_box.show() + else: + self.translation_text_box.hide() + def on_device_changed(self, device_id: int): self.selected_device_id = device_id self.reset_recording_amplitude_listener() @@ -259,6 +286,28 @@ class RecordingTranscriberWidget(QWidget): self.transcriber.error.connect(self.transcription_thread.quit) self.transcriber.error.connect(self.transcriber.deleteLater) + if self.transcription_options.enable_llm_translation: + self.translation_thread = QThread() + + self.translator = Translator( + self.transcription_options, + self.transcription_options_group_box.advanced_settings_dialog, + ) + + self.translator.moveToThread(self.translation_thread) + + self.translation_thread.started.connect(self.translator.start) + self.translation_thread.finished.connect( + self.translation_thread.deleteLater + ) + + self.translator.finished.connect(self.translation_thread.quit) + self.translator.finished.connect(self.translator.deleteLater) + + self.translator.translation.connect(self.on_next_translation) + + self.translation_thread.start() + self.transcription_thread.start() def on_download_model_progress(self, progress: Tuple[float, float]): @@ -288,22 +337,45 @@ class RecordingTranscriberWidget(QWidget): self.set_recording_status_stopped() self.record_button.setDisabled(False) + @staticmethod + def strip_newlines(text): + return text.replace('\r\n', os.linesep).replace('\n', os.linesep) + def on_next_transcription(self, text: str): text = text.strip() if len(text) > 0: - self.text_box.moveCursor(QTextCursor.MoveOperation.End) - if len(self.text_box.toPlainText()) > 0: - self.text_box.insertPlainText("\n\n") - self.text_box.insertPlainText(text) - self.text_box.moveCursor(QTextCursor.MoveOperation.End) + if self.translator is not None: + self.translator.enqueue(text) + + self.transcription_text_box.moveCursor(QTextCursor.MoveOperation.End) + if len(self.transcription_text_box.toPlainText()) > 0: + self.transcription_text_box.insertPlainText("\n\n") + self.transcription_text_box.insertPlainText(text) + self.transcription_text_box.moveCursor(QTextCursor.MoveOperation.End) if self.export_enabled: - with open(self.export_file, "a") as f: + with open(self.transcript_export_file, "a") as f: + f.write(text + "\n\n") + + def on_next_translation(self, text: str, _: Optional[int] = None): + if len(text) > 0: + self.translation_text_box.moveCursor(QTextCursor.MoveOperation.End) + if len(self.translation_text_box.toPlainText()) > 0: + self.translation_text_box.insertPlainText("\n\n") + self.translation_text_box.insertPlainText(self.strip_newlines(text)) + self.translation_text_box.moveCursor(QTextCursor.MoveOperation.End) + + if self.export_enabled: + with open(self.translation_export_file, "a") as f: f.write(text + "\n\n") def stop_recording(self): if self.transcriber is not None: self.transcriber.stop_recording() + + if self.translator is not None: + self.translator.stop() + # Disable record button until the transcription is actually stopped in the background self.record_button.setDisabled(True) @@ -341,7 +413,7 @@ class RecordingTranscriberWidget(QWidget): def reset_recording_controls(self): # Clear text box placeholder because the first chunk takes a while to process - self.text_box.setPlaceholderText("") + self.transcription_text_box.setPlaceholderText("") self.reset_record_button() self.reset_model_download() @@ -361,6 +433,9 @@ class RecordingTranscriberWidget(QWidget): self.recording_amplitude_listener.deleteLater() self.recording_amplitude_listener = None + if self.translator is not None: + self.translator.stop() + self.settings.set_value( Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, self.transcription_options.language, @@ -379,5 +454,17 @@ class RecordingTranscriberWidget(QWidget): self.settings.set_value( Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_ENABLE_LLM_TRANSLATION, + self.transcription_options.enable_llm_translation, + ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_LLM_MODEL, + self.transcription_options.llm_model, + ) + self.settings.set_value( + Settings.Key.RECORDING_TRANSCRIBER_LLM_PROMPT, + self.transcription_options.llm_prompt, + ) return super().closeEvent(event) diff --git a/buzz/widgets/transcriber/advanced_settings_dialog.py b/buzz/widgets/transcriber/advanced_settings_dialog.py index 2d86c3a3..964e5776 100644 --- a/buzz/widgets/transcriber/advanced_settings_dialog.py +++ b/buzz/widgets/transcriber/advanced_settings_dialog.py @@ -3,12 +3,16 @@ from PyQt6.QtWidgets import ( QDialog, QWidget, QDialogButtonBox, + QCheckBox, + QPlainTextEdit, QFormLayout, + QLabel, ) from buzz.locale import _ from buzz.model_loader import ModelType from buzz.transcriber.transcriber import TranscriptionOptions +from buzz.settings.settings import Settings from buzz.widgets.line_edit import LineEdit from buzz.widgets.transcriber.initial_prompt_text_edit import InitialPromptTextEdit from buzz.widgets.transcriber.temperature_validator import TemperatureValidator @@ -24,16 +28,16 @@ class AdvancedSettingsDialog(QDialog): super().__init__(parent) self.transcription_options = transcription_options + self.settings = Settings() self.setWindowTitle(_("Advanced Settings")) - button_box = QDialogButtonBox( - QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self - ) - button_box.accepted.connect(self.accept) - layout = QFormLayout(self) + transcription_settings_title= _("Speech recognition settings") + transcription_settings_title_label = QLabel(f"

{transcription_settings_title}

", self) + layout.addRow("", transcription_settings_title_label) + default_temperature_text = ", ".join( [str(temp) for temp in transcription_options.temperature] ) @@ -48,6 +52,8 @@ class AdvancedSettingsDialog(QDialog): transcription_options.model.model_type == ModelType.WHISPER ) + layout.addRow(_("Temperature:"), self.temperature_line_edit) + self.initial_prompt_text_edit = InitialPromptTextEdit( transcription_options.initial_prompt, transcription_options.model.model_type, @@ -57,12 +63,43 @@ class AdvancedSettingsDialog(QDialog): self.on_initial_prompt_changed ) - layout.addRow(_("Temperature:"), self.temperature_line_edit) layout.addRow(_("Initial Prompt:"), self.initial_prompt_text_edit) + + translation_settings_title= _("Translation settings") + translation_settings_title_label = QLabel(f"

{translation_settings_title}

", self) + layout.addRow("", translation_settings_title_label) + + self.enable_llm_translation_checkbox = QCheckBox(_("Enable AI translation")) + self.enable_llm_translation_checkbox.setChecked(self.transcription_options.enable_llm_translation) + self.enable_llm_translation_checkbox.stateChanged.connect(self.on_enable_llm_translation_changed) + layout.addRow("", self.enable_llm_translation_checkbox) + + self.llm_model_line_edit = LineEdit(self.transcription_options.llm_model, self) + self.llm_model_line_edit.textChanged.connect( + self.on_llm_model_changed + ) + self.llm_model_line_edit.setMinimumWidth(170) + self.llm_model_line_edit.setEnabled(self.transcription_options.enable_llm_translation) + self.llm_model_line_edit.setPlaceholderText("gpt-3.5-turbo") + layout.addRow(_("AI model:"), self.llm_model_line_edit) + + self.llm_prompt_text_edit = QPlainTextEdit(self.transcription_options.llm_prompt) + self.llm_prompt_text_edit.setEnabled(self.transcription_options.enable_llm_translation) + self.llm_prompt_text_edit.setPlaceholderText(_("Enter instructions for AI on how to translate...")) + self.llm_prompt_text_edit.setMinimumWidth(170) + self.llm_prompt_text_edit.setFixedHeight(115) + self.llm_prompt_text_edit.textChanged.connect(self.on_llm_prompt_changed) + layout.addRow(_("Instructions for AI:"), self.llm_prompt_text_edit) + + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self + ) + button_box.accepted.connect(self.accept) + layout.addWidget(button_box) self.setLayout(layout) - self.setFixedSize(self.sizeHint()) + self.resize(self.sizeHint()) def on_temperature_changed(self, text: str): try: @@ -77,3 +114,20 @@ class AdvancedSettingsDialog(QDialog): self.initial_prompt_text_edit.toPlainText() ) self.transcription_options_changed.emit(self.transcription_options) + + def on_enable_llm_translation_changed(self, state): + self.transcription_options.enable_llm_translation = state == 2 + self.transcription_options_changed.emit(self.transcription_options) + + self.llm_model_line_edit.setEnabled(self.transcription_options.enable_llm_translation) + self.llm_prompt_text_edit.setEnabled(self.transcription_options.enable_llm_translation) + + def on_llm_model_changed(self, text: str): + self.transcription_options.llm_model = text + self.transcription_options_changed.emit(self.transcription_options) + + def on_llm_prompt_changed(self): + self.transcription_options.llm_prompt = ( + self.llm_prompt_text_edit.toPlainText() + ) + self.transcription_options_changed.emit(self.transcription_options) diff --git a/buzz/widgets/transcriber/initial_prompt_text_edit.py b/buzz/widgets/transcriber/initial_prompt_text_edit.py index dd077305..26959f4c 100644 --- a/buzz/widgets/transcriber/initial_prompt_text_edit.py +++ b/buzz/widgets/transcriber/initial_prompt_text_edit.py @@ -10,3 +10,4 @@ class InitialPromptTextEdit(QPlainTextEdit): self.setPlaceholderText(_("Enter prompt...")) self.setEnabled(model_type.supports_initial_prompt) self.setMinimumWidth(350) + self.setFixedHeight(115) diff --git a/buzz/widgets/transcriber/transcription_options_group_box.py b/buzz/widgets/transcriber/transcription_options_group_box.py index 7c191d4b..ce61de17 100644 --- a/buzz/widgets/transcriber/transcription_options_group_box.py +++ b/buzz/widgets/transcriber/transcription_options_group_box.py @@ -39,6 +39,13 @@ class TranscriptionOptionsGroupBox(QGroupBox): ) self.model_type_combo_box.changed.connect(self.on_model_type_changed) + self.advanced_settings_dialog = AdvancedSettingsDialog( + transcription_options=self.transcription_options, parent=self + ) + self.advanced_settings_dialog.transcription_options_changed.connect( + self.on_transcription_options_changed + ) + self.whisper_model_size_combo_box = QComboBox(self) self.whisper_model_size_combo_box.addItems( [size.value.title() for size in WhisperModelSize] @@ -102,13 +109,7 @@ class TranscriptionOptionsGroupBox(QGroupBox): self.transcription_options_changed.emit(self.transcription_options) def open_advanced_settings(self): - dialog = AdvancedSettingsDialog( - transcription_options=self.transcription_options, parent=self - ) - dialog.transcription_options_changed.connect( - self.on_transcription_options_changed - ) - dialog.exec() + self.advanced_settings_dialog.exec() def on_transcription_options_changed( self, transcription_options: TranscriptionOptions diff --git a/buzz/widgets/transcription_tasks_table_widget.py b/buzz/widgets/transcription_tasks_table_widget.py index f6d3f8ea..b0ddb6b5 100644 --- a/buzz/widgets/transcription_tasks_table_widget.py +++ b/buzz/widgets/transcription_tasks_table_widget.py @@ -59,7 +59,8 @@ def format_record_status_text(record: QSqlRecord) -> str: status = FileTranscriptionTask.Status(record.value("status")) match status: case FileTranscriptionTask.Status.IN_PROGRESS: - return f'{_("In Progress")} ({record.value("progress") :.0%})' + in_progress_label = _("In Progress") + return f'{in_progress_label} ({record.value("progress") :.0%})' case FileTranscriptionTask.Status.COMPLETED: status = _("Completed") started_at = record.value("time_started") @@ -68,7 +69,8 @@ def format_record_status_text(record: QSqlRecord) -> str: status += f" ({TranscriptionTasksTableWidget.format_timedelta(datetime.fromisoformat(completed_at) - datetime.fromisoformat(started_at))})" return status case FileTranscriptionTask.Status.FAILED: - return f'{_("Failed")} ({record.value("error_message")})' + failed_label = _("Failed") + return f'{failed_label} ({record.value("error_message")})' case FileTranscriptionTask.Status.CANCELED: return _("Canceled") case FileTranscriptionTask.Status.QUEUED: diff --git a/buzz/widgets/transcription_viewer/export_transcription_menu.py b/buzz/widgets/transcription_viewer/export_transcription_menu.py index 55e9c8f5..42ac50ff 100644 --- a/buzz/widgets/transcription_viewer/export_transcription_menu.py +++ b/buzz/widgets/transcription_viewer/export_transcription_menu.py @@ -1,3 +1,4 @@ +import logging from PyQt6.QtGui import QAction from PyQt6.QtWidgets import QWidget, QMenu, QFileDialog @@ -23,15 +24,48 @@ class ExportTranscriptionMenu(QMenu): self.transcription = transcription self.transcription_service = transcription_service - actions = [ - QAction(text=output_format.value.upper(), parent=self) - for output_format in OutputFormat + self.segments = [ + Segment( + start=segment.start_time, + end=segment.end_time, + text=segment.text, + translation=segment.translation) + for segment in self.transcription_service.get_transcription_segments( + transcription_id=self.transcription.id_as_uuid + ) ] + + if self.segments and len(self.segments[0].translation) > 0: + text_label = _("Text") + translation_label = _("Translation") + actions = [ + action + for output_format in OutputFormat + for action in [ + QAction(text=f"{output_format.value.upper()} - {text_label}", parent=self), + QAction(text=f"{output_format.value.upper()} - {translation_label}", parent=self) + ] + ] + else: + actions = [ + QAction(text=output_format.value.upper(), parent=self) + for output_format in OutputFormat + ] self.addActions(actions) self.triggered.connect(self.on_menu_triggered) + @staticmethod + def extract_format_and_segment_key(action_text: str): + parts = action_text.split('-') + output_format = parts[0].strip() + label = parts[1].strip() if len(parts) > 1 else None + segment_key = 'translation' if label == _('Translation') else 'text' + + return output_format, segment_key + def on_menu_triggered(self, action: QAction): - output_format = OutputFormat[action.text()] + output_format_value, segment_key = self.extract_format_and_segment_key(action.text()) + output_format = OutputFormat[output_format_value] default_path = self.transcription.get_output_file_path( output_format=output_format @@ -47,15 +81,9 @@ class ExportTranscriptionMenu(QMenu): if output_file_path == "": return - segments = [ - Segment(start=segment.start_time, end=segment.end_time, text=segment.text) - for segment in self.transcription_service.get_transcription_segments( - transcription_id=self.transcription.id_as_uuid - ) - ] - write_output( path=output_file_path, - segments=segments, + segments=self.segments, output_format=output_format, + segment_key=segment_key ) diff --git a/buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py b/buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py index 4f4f2b0d..11ec25ff 100644 --- a/buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py +++ b/buzz/widgets/transcription_viewer/transcription_segments_editor_widget.py @@ -1,18 +1,22 @@ import enum +import logging from dataclasses import dataclass from typing import Optional from uuid import UUID from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex, QItemSelection from PyQt6.QtSql import QSqlTableModel, QSqlRecord +from PyQt6.QtGui import QFontMetrics, QTextOption from PyQt6.QtWidgets import ( QWidget, QTableView, QStyledItemDelegate, QAbstractItemView, + QTextEdit, ) from buzz.locale import _ +from buzz.translator import Translator from buzz.transcriber.file_transcriber import to_timestamp @@ -21,6 +25,7 @@ class Column(enum.Enum): END = enum.auto() START = enum.auto() TEXT = enum.auto() + TRANSLATION = enum.auto() TRANSCRIPTION_ID = enum.auto() @@ -37,6 +42,18 @@ class TimeStampDelegate(QStyledItemDelegate): return to_timestamp(value) +class WordWrapDelegate(QStyledItemDelegate): + def createEditor(self, parent, option, index): + editor = QTextEdit(parent) + editor.setWordWrapMode(QTextOption.WrapMode.WordWrap) + editor.setAcceptRichText(False) + + return editor + + def setModelData(self, editor, model, index): + model.setData(index, editor.toPlainText()) + + class TranscriptionSegmentModel(QSqlTableModel): def __init__(self, transcription_id: UUID): super().__init__() @@ -52,20 +69,31 @@ class TranscriptionSegmentModel(QSqlTableModel): class TranscriptionSegmentsEditorWidget(QTableView): + PARENT_PADDINGS = 40 segment_selected = pyqtSignal(QSqlRecord) - def __init__(self, transcription_id: UUID, parent: Optional[QWidget]): + def __init__( + self, + transcription_id: UUID, + translator: Translator, + parent: Optional[QWidget] + ): super().__init__(parent) + self.translator = translator + self.translator.translation.connect(self.update_translation) + model = TranscriptionSegmentModel(transcription_id=transcription_id) self.setModel(model) timestamp_delegate = TimeStampDelegate() + word_wrap_delegate = WordWrapDelegate() self.column_definitions: list[ColDef] = [ ColDef("start", _("Start"), Column.START, delegate=timestamp_delegate), ColDef("end", _("End"), Column.END, delegate=timestamp_delegate), - ColDef("text", _("Text"), Column.TEXT), + ColDef("text", _("Text"), Column.TEXT, delegate=word_wrap_delegate), + ColDef("translation", _("Translation"), Column.TRANSLATION, delegate=word_wrap_delegate), ] for i in range(model.columnCount()): @@ -90,9 +118,52 @@ class TranscriptionSegmentsEditorWidget(QTableView): self.selectionModel().selectionChanged.connect(self.on_selection_changed) model.select() + self.has_translations = self.has_non_empty_translation() + # Show start before end self.horizontalHeader().swapSections(1, 2) - self.resizeColumnsToContents() + + font_metrics = QFontMetrics(self.font()) + max_row_height = font_metrics.height() * 4 + for row in range(self.model().rowCount()): + self.setRowHeight(row, max_row_height) + + self.setColumnWidth(Column.START.value, 95) + self.setColumnWidth(Column.END.value, 95) + + self.setWordWrap(True) + + def has_non_empty_translation(self) -> bool: + for i in range(self.model().rowCount()): + if self.model().record(i).value("translation").strip(): + return True + return False + + def resizeEvent(self, event): + super().resizeEvent(event) + + if not self.has_translations: + self.hideColumn(Column.TRANSLATION.value) + else: + self.showColumn(Column.TRANSLATION.value) + + text_column_count = 2 if self.has_translations else 1 + + time_column_widths = self.columnWidth(Column.START.value) + self.columnWidth(Column.END.value) + text_column_width = ( + int((self.parent().width() - self.PARENT_PADDINGS - time_column_widths) / text_column_count)) + + self.setColumnWidth(Column.TEXT.value, text_column_width) + self.setColumnWidth(Column.TRANSLATION.value, text_column_width) + + def update_translation(self, translation: str, segment_id: Optional[int] = None): + self.has_translations = True + self.resizeEvent(None) + + for row in range(self.model().rowCount()): + if self.model().record(row).value("id") == segment_id: + self.model().setData(self.model().index(row, Column.TRANSLATION.value), translation) + break def on_selection_changed( self, selected: QItemSelection, _deselected: QItemSelection diff --git a/buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py b/buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py index c38fff89..24f7c94c 100644 --- a/buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py +++ b/buzz/widgets/transcription_viewer/transcription_view_mode_tool_button.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Optional from PyQt6.QtCore import pyqtSignal, Qt @@ -10,8 +11,14 @@ from buzz.settings.shortcuts import Shortcuts from buzz.widgets.icon import VisibilityIcon +class ViewMode(Enum): + TEXT = "Text" + TRANSLATION = "Translation" + TIMESTAMPS = "Timestamps" + + class TranscriptionViewModeToolButton(QToolButton): - view_mode_changed = pyqtSignal(bool) # is_timestamps? + view_mode_changed = pyqtSignal(ViewMode) def __init__(self, shortcuts: Shortcuts, parent: Optional[QWidget] = None): super().__init__(parent) @@ -26,12 +33,19 @@ class TranscriptionViewModeToolButton(QToolButton): menu.addAction( _("Text"), QKeySequence(shortcuts.get(Shortcut.VIEW_TRANSCRIPT_TEXT)), - lambda: self.view_mode_changed.emit(False), + lambda: self.view_mode_changed.emit(ViewMode.TEXT), + ) + + menu.addAction( + _("Translation"), + QKeySequence(shortcuts.get(Shortcut.VIEW_TRANSCRIPT_TRANSLATION)), + lambda: self.view_mode_changed.emit(ViewMode.TRANSLATION) ) menu.addAction( _("Timestamps"), QKeySequence(shortcuts.get(Shortcut.VIEW_TRANSCRIPT_TIMESTAMPS)), - lambda: self.view_mode_changed.emit(True), + lambda: self.view_mode_changed.emit(ViewMode.TIMESTAMPS), ) + self.setMenu(menu) diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py index ffed3e59..42e4bd8a 100644 --- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py +++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py @@ -1,8 +1,9 @@ +import logging import platform from typing import Optional from uuid import UUID -from PyQt6.QtCore import Qt +from PyQt6.QtCore import Qt, QThread from PyQt6.QtGui import QFont from PyQt6.QtMultimedia import QMediaPlayer from PyQt6.QtSql import QSqlRecord @@ -11,6 +12,7 @@ from PyQt6.QtWidgets import ( QVBoxLayout, QToolButton, QLabel, + QMessageBox, ) from buzz.locale import _ @@ -18,25 +20,36 @@ from buzz.db.entity.transcription import Transcription from buzz.db.service.transcription_service import TranscriptionService from buzz.paths import file_path_as_title from buzz.settings.shortcuts import Shortcuts +from buzz.settings.settings import Settings +from buzz.store.keyring_store import get_password, Key from buzz.widgets.audio_player import AudioPlayer from buzz.widgets.icon import ( FileDownloadIcon, + TranslateIcon ) +from buzz.translator import Translator from buzz.widgets.text_display_box import TextDisplayBox from buzz.widgets.toolbar import ToolBar +from buzz.transcriber.transcriber import TranscriptionOptions +from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog from buzz.widgets.transcription_viewer.export_transcription_menu import ( ExportTranscriptionMenu, ) +from buzz.widgets.preferences_dialog.models.file_transcription_preferences import ( + FileTranscriptionPreferences, +) from buzz.widgets.transcription_viewer.transcription_segments_editor_widget import ( TranscriptionSegmentsEditorWidget, ) from buzz.widgets.transcription_viewer.transcription_view_mode_tool_button import ( TranscriptionViewModeToolButton, + ViewMode ) class TranscriptionViewerWidget(QWidget): transcription: Transcription + settings = Settings() def __init__( self, @@ -55,10 +68,45 @@ class TranscriptionViewerWidget(QWidget): self.setWindowTitle(file_path_as_title(transcription.file)) - self.is_showing_timestamps = True + self.translation_thread = None + self.translator = None + self.view_mode = ViewMode.TIMESTAMPS + + self.openai_access_token = get_password(Key.OPENAI_API_KEY) + + preferences = self.load_preferences() + + ( + self.transcription_options, + self.file_transcription_options, + ) = preferences.to_transcription_options( + openai_access_token=self.openai_access_token, + ) + + self.transcription_options_dialog = AdvancedSettingsDialog( + transcription_options=self.transcription_options, parent=self + ) + self.transcription_options_dialog.transcription_options_changed.connect( + self.on_transcription_options_changed + ) + + self.translator = Translator( + self.transcription_options, + self.transcription_options_dialog, + ) + + self.translation_thread = QThread() + self.translator.moveToThread(self.translation_thread) + + self.translation_thread.started.connect(self.translator.start) + + self.translation_thread.start() self.table_widget = TranscriptionSegmentsEditorWidget( - transcription_id=UUID(hex=transcription.id), parent=self + transcription_id=UUID(hex=transcription.id), + translator=self.translator, + + parent=self ) self.table_widget.segment_selected.connect(self.on_segment_selected) @@ -102,6 +150,16 @@ class TranscriptionViewerWidget(QWidget): export_tool_button.setPopupMode(QToolButton.ToolButtonPopupMode.InstantPopup) toolbar.addWidget(export_tool_button) + translate_button = QToolButton() + translate_button.setText(_("Translate")) + translate_button.setIcon(TranslateIcon(self)) + translate_button.setToolButtonStyle( + Qt.ToolButtonStyle.ToolButtonTextBesideIcon + ) + translate_button.clicked.connect(self.on_translate_button_clicked) + + toolbar.addWidget(translate_button) + layout.setMenuBar(toolbar) layout.addWidget(self.table_widget) @@ -114,10 +172,10 @@ class TranscriptionViewerWidget(QWidget): self.reset_view() def reset_view(self): - if self.is_showing_timestamps: + if self.view_mode == ViewMode.TIMESTAMPS: self.text_display_box.hide() self.table_widget.show() - else: + elif self.view_mode == ViewMode.TEXT: segments = self.transcription_service.get_transcription_segments( transcription_id=self.transcription.id_as_uuid ) @@ -126,9 +184,19 @@ class TranscriptionViewerWidget(QWidget): ) self.text_display_box.show() self.table_widget.hide() + else: # ViewMode.TRANSLATION + # TODO add check for if translation exists + segments = self.transcription_service.get_transcription_segments( + transcription_id=self.transcription.id_as_uuid + ) + self.text_display_box.setPlainText( + " ".join(segment.translation.strip() for segment in segments) + ) + self.text_display_box.show() + self.table_widget.hide() - def on_view_mode_changed(self, is_timestamps: bool) -> None: - self.is_showing_timestamps = is_timestamps + def on_view_mode_changed(self, view_mode: ViewMode) -> None: + self.view_mode = view_mode self.reset_view() def on_segment_selected(self, segment: QSqlRecord): @@ -154,3 +222,44 @@ class TranscriptionViewerWidget(QWidget): ) if current_segment is not None: self.current_segment_label.setText(current_segment.value("text")) + + def load_preferences(self): + self.settings.settings.beginGroup("file_transcriber") + preferences = FileTranscriptionPreferences.load(settings=self.settings.settings) + self.settings.settings.endGroup() + return preferences + + def open_advanced_settings(self): + self.transcription_options_dialog.show() + + def on_transcription_options_changed( + self, transcription_options: TranscriptionOptions + ): + self.transcription_options = transcription_options + + def on_translate_button_clicked(self): + if len(self.openai_access_token) == 0: + QMessageBox.information( + self, + _("API Key Required"), + _("Please enter OpenAI API Key in preferences") + ) + + return + + if self.transcription_options.llm_model == "" or self.transcription_options.llm_prompt == "": + self.transcription_options_dialog.show() + return + + segments = self.table_widget.segments() + for segment in segments: + self.translator.enqueue(segment.value("text"), segment.value("id")) + + def closeEvent(self, event): + self.hide() + + self.translator.stop() + self.translation_thread.quit() + self.translation_thread.wait() + + super().closeEvent(event) diff --git a/tests/gui_test.py b/tests/gui_test.py index f3f11000..780fa850 100644 --- a/tests/gui_test.py +++ b/tests/gui_test.py @@ -111,7 +111,11 @@ class TestAdvancedSettingsDialog: def test_should_update_advanced_settings(self, qtbot: QtBot): dialog = AdvancedSettingsDialog( transcription_options=TranscriptionOptions( - temperature=(0.0, 0.8), initial_prompt="prompt" + temperature=(0.0, 0.8), + initial_prompt="prompt", + enable_llm_translation=False, + llm_model="", + llm_prompt="" ) ) qtbot.add_widget(dialog) @@ -122,12 +126,21 @@ class TestAdvancedSettingsDialog: assert dialog.windowTitle() == _("Advanced Settings") assert dialog.temperature_line_edit.text() == "0.0, 0.8" assert dialog.initial_prompt_text_edit.toPlainText() == "prompt" + assert dialog.enable_llm_translation_checkbox.isChecked() is False + assert dialog.llm_model_line_edit.text() == "" + assert dialog.llm_prompt_text_edit.toPlainText() == "" dialog.temperature_line_edit.setText("0.0, 0.8, 1.0") dialog.initial_prompt_text_edit.setPlainText("new prompt") + dialog.enable_llm_translation_checkbox.setChecked(True) + dialog.llm_model_line_edit.setText("model") + dialog.llm_prompt_text_edit.setPlainText("Please translate this text") assert transcription_options_mock.call_args[0][0].temperature == (0.0, 0.8, 1.0) assert transcription_options_mock.call_args[0][0].initial_prompt == "new prompt" + assert transcription_options_mock.call_args[0][0].enable_llm_translation is True + assert transcription_options_mock.call_args[0][0].llm_model == "model" + assert transcription_options_mock.call_args[0][0].llm_prompt == "Please translate this text" class TestTemperatureValidator: diff --git a/tests/translator_test.py b/tests/translator_test.py new file mode 100644 index 00000000..56db2fc3 --- /dev/null +++ b/tests/translator_test.py @@ -0,0 +1,116 @@ +import time +import pytest +from queue import Empty +from unittest.mock import Mock, patch, create_autospec + +from PyQt6.QtCore import QThread + +from buzz.translator import Translator +from buzz.transcriber.transcriber import TranscriptionOptions +from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog + + +class TestTranslator: + @patch('buzz.translator.OpenAI', autospec=True) + @patch('buzz.translator.queue.Queue', autospec=True) + def test_start(self, mock_queue, mock_openai): + def side_effect(*args, **kwargs): + side_effect.call_count += 1 + + if side_effect.call_count >= 5: + translator.is_running = False + + if side_effect.call_count < 3: + raise Empty + return "Hello, how are you?", None + + side_effect.call_count = 0 + + mock_queue.get.side_effect = side_effect + mock_chat = Mock() + mock_openai.return_value.chat = mock_chat + mock_chat.completions.create.return_value = Mock( + choices=[Mock(message=Mock(content="AI Translated: Hello, how are you?"))] + ) + + transcription_options = TranscriptionOptions( + enable_llm_translation=False, + llm_model="llama3", + llm_prompt="Please translate this text:", + ) + translator = Translator( + transcription_options, + AdvancedSettingsDialog( + transcription_options=transcription_options, parent=None + ) + ) + translator.queue = mock_queue + + translator.start() + + mock_queue.get.assert_called() + mock_chat.completions.create.assert_called() + + @patch('buzz.translator.OpenAI', autospec=True) + def test_translator(self, mock_openai, qtbot): + + self.on_next_translation_called = False + + def on_next_translation(text: str): + self.on_next_translation_called = True + assert text.startswith("AI Translated:") + + mock_chat = Mock() + mock_openai.return_value.chat = mock_chat + mock_chat.completions.create.return_value = Mock( + choices=[Mock(message=Mock(content="AI Translated: Hello, how are you?"))] + ) + + self.translation_thread = QThread() + self.transcription_options = TranscriptionOptions( + enable_llm_translation=False, + llm_model="llama3", + llm_prompt="Please translate this text:", + ) + + self.translator = Translator( + self.transcription_options, + AdvancedSettingsDialog( + transcription_options=self.transcription_options, parent=None + ) + ) + + self.translator.moveToThread(self.translation_thread) + + self.translation_thread.started.connect(self.translator.start) + self.translation_thread.finished.connect( + self.translation_thread.deleteLater + ) + + self.translator.finished.connect(self.translation_thread.quit) + self.translator.finished.connect(self.translator.deleteLater) + + self.translator.translation.connect(on_next_translation) + + self.translation_thread.start() + + time.sleep(3) + assert self.translator.is_running + + self.translator.enqueue("Hello, how are you?") + + def translation_signal_received(): + assert self.on_next_translation_called + + qtbot.wait_until(translation_signal_received, timeout=60 * 1000) + + if self.translator is not None: + self.translator.stop() + self.translator.deleteLater() + + if self.translation_thread is not None: + self.translation_thread.quit() + self.translation_thread.deleteLater() + + # Wait to clean-up threads + time.sleep(3) diff --git a/tests/widgets/export_transcription_menu_test.py b/tests/widgets/export_transcription_menu_test.py index 4430f6f4..819ea054 100644 --- a/tests/widgets/export_transcription_menu_test.py +++ b/tests/widgets/export_transcription_menu_test.py @@ -30,9 +30,9 @@ class TestExportTranscriptionMenu: whisper_model_size=WhisperModelSize.SMALL.value, ) ) - transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", str(id))) + transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", "", str(id))) transcription_segment_dao.insert( - TranscriptionSegment(299, 329, "venue dans", str(id)) + TranscriptionSegment(299, 329, "venue dans", "", str(id)) ) return transcription_dao.find_by_id(str(id)) diff --git a/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py b/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py index 5e052bb2..3b62cc14 100644 --- a/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py @@ -29,6 +29,9 @@ class TestFolderWatchPreferencesWidget: word_level_timings=False, temperature=DEFAULT_WHISPER_TEMPERATURE, initial_prompt="", + enable_llm_translation=False, + llm_model="", + llm_prompt="", output_formats=set(), ), ), diff --git a/tests/widgets/preferences_dialog/general_preferences_widget_test.py b/tests/widgets/preferences_dialog/general_preferences_widget_test.py index 2aab14c2..0510830d 100644 --- a/tests/widgets/preferences_dialog/general_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/general_preferences_widget_test.py @@ -101,10 +101,10 @@ class TestGeneralPreferencesWidget: assert openai_base_url == "" assert widget.custom_openai_base_url_line_edit.text() == "" - widget.custom_openai_base_url_line_edit.setText("https://localhost:8000/v1") + widget.custom_openai_base_url_line_edit.setText("http://localhost:11434/v1") updated_openai_base_url = settings.value( key=Settings.Key.CUSTOM_OPENAI_BASE_URL, default_value="" ) - assert updated_openai_base_url == "https://localhost:8000/v1" + assert updated_openai_base_url == "http://localhost:11434/v1" diff --git a/tests/widgets/recording_transcriber_widget_test.py b/tests/widgets/recording_transcriber_widget_test.py index c4576109..b23a138a 100644 --- a/tests/widgets/recording_transcriber_widget_test.py +++ b/tests/widgets/recording_transcriber_widget_test.py @@ -45,10 +45,10 @@ class TestRecordingTranscriberWidget: widget.device_sample_rate = 16_000 qtbot.add_widget(widget) - assert len(widget.text_box.toPlainText()) == 0 + assert len(widget.transcription_text_box.toPlainText()) == 0 def assert_text_box_contains_text(): - assert len(widget.text_box.toPlainText()) > 0 + assert len(widget.transcription_text_box.toPlainText()) > 0 widget.record_button.click() qtbot.wait_until(callback=assert_text_box_contains_text, timeout=60 * 1000) @@ -56,7 +56,7 @@ class TestRecordingTranscriberWidget: with qtbot.wait_signal(widget.transcription_thread.finished, timeout=60 * 1000): widget.stop_recording() - assert len(widget.text_box.toPlainText()) > 0 + assert len(widget.transcription_text_box.toPlainText()) > 0 widget.close() # on CI transcribed output is garbage, so we check if there is anything @@ -90,10 +90,10 @@ class TestRecordingTranscriberWidget: widget.export_enabled = True qtbot.add_widget(widget) - assert len(widget.text_box.toPlainText()) == 0 + assert len(widget.transcription_text_box.toPlainText()) == 0 def assert_text_box_contains_text(): - assert len(widget.text_box.toPlainText()) > 0 + assert len(widget.transcription_text_box.toPlainText()) > 0 widget.record_button.click() qtbot.wait_until(callback=assert_text_box_contains_text, timeout=60 * 1000) @@ -101,9 +101,9 @@ class TestRecordingTranscriberWidget: with qtbot.wait_signal(widget.transcription_thread.finished, timeout=60 * 1000): widget.stop_recording() - assert len(widget.text_box.toPlainText()) > 0 + assert len(widget.transcription_text_box.toPlainText()) > 0 - with open(widget.export_file, 'r') as file: + with open(widget.transcript_export_file, 'r') as file: contents = file.read() assert len(contents) > 0 diff --git a/tests/widgets/shortcuts_editor_widget_test.py b/tests/widgets/shortcuts_editor_widget_test.py index 98a0d4b4..024ae306 100644 --- a/tests/widgets/shortcuts_editor_widget_test.py +++ b/tests/widgets/shortcuts_editor_widget_test.py @@ -35,6 +35,7 @@ class TestShortcutsEditorWidget: (_("Import URL"), "Ctrl+U"), (_("Open Preferences Window"), "Ctrl+,"), (_("View Transcript Text"), "Ctrl+E"), + (_("View Transcript Translation"), "Ctrl+L"), (_("View Transcript Timestamps"), "Ctrl+T"), (_("Clear History"), "Ctrl+S"), (_("Cancel Transcription"), "Ctrl+X"), diff --git a/tests/widgets/transcription_task_folder_watcher_test.py b/tests/widgets/transcription_task_folder_watcher_test.py index 940c5bef..c6b2d5da 100644 --- a/tests/widgets/transcription_task_folder_watcher_test.py +++ b/tests/widgets/transcription_task_folder_watcher_test.py @@ -46,6 +46,9 @@ class TestTranscriptionTaskFolderWatcher: word_level_timings=False, temperature=DEFAULT_WHISPER_TEMPERATURE, initial_prompt="", + enable_llm_translation=False, + llm_model="", + llm_prompt="", output_formats=set(), ), ), @@ -86,6 +89,9 @@ class TestTranscriptionTaskFolderWatcher: word_level_timings=False, temperature=DEFAULT_WHISPER_TEMPERATURE, initial_prompt="", + enable_llm_translation=False, + llm_model="", + llm_prompt="", output_formats=set(), ), ), diff --git a/tests/widgets/transcription_viewer_test.py b/tests/widgets/transcription_viewer_test.py index 739edaa6..99980b30 100644 --- a/tests/widgets/transcription_viewer_test.py +++ b/tests/widgets/transcription_viewer_test.py @@ -1,12 +1,18 @@ import uuid +import time import pytest from pytestqt.qtbot import QtBot +from buzz.locale import _ from buzz.db.entity.transcription import Transcription from buzz.db.entity.transcription_segment import TranscriptionSegment from buzz.model_loader import ModelType, WhisperModelSize from buzz.transcriber.transcriber import Task +from buzz.widgets.transcription_viewer.transcription_view_mode_tool_button import ( + TranscriptionViewModeToolButton, + ViewMode +) from buzz.widgets.transcription_viewer.transcription_segments_editor_widget import ( TranscriptionSegmentsEditorWidget, ) @@ -32,9 +38,9 @@ class TestTranscriptionViewerWidget: whisper_model_size=WhisperModelSize.SMALL.value, ) ) - transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", str(id))) + transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", "", str(id))) transcription_segment_dao.insert( - TranscriptionSegment(299, 329, "venue dans", str(id)) + TranscriptionSegment(299, 329, "venue dans", "", str(id)) ) return transcription_dao.find_by_id(str(id)) @@ -55,6 +61,7 @@ class TestTranscriptionViewerWidget: assert editor.model().index(0, 1).data() == 299 assert editor.model().index(0, 2).data() == 40 assert editor.model().index(0, 3).data() == "Bien" + widget.close() def test_should_update_segment_text( self, qtbot, transcription, transcription_service, shortcuts @@ -68,3 +75,25 @@ class TestTranscriptionViewerWidget: assert isinstance(editor, TranscriptionSegmentsEditorWidget) editor.model().setData(editor.model().index(0, 3), "Biens") + widget.close() + + def test_text_button_changes_view_mode( + self, qtbot, transcription, transcription_service, shortcuts + ): + widget = TranscriptionViewerWidget( + transcription, transcription_service, shortcuts + ) + qtbot.add_widget(widget) + + view_mode_tool_button = widget.findChild(TranscriptionViewModeToolButton) + menu = view_mode_tool_button.menu() + + text_action = next(action for action in menu.actions() if action.text() == _("Text")) + text_action.trigger() + assert widget.view_mode == ViewMode.TEXT + + text_action = next(action for action in menu.actions() if action.text() == _("Translation")) + text_action.trigger() + assert widget.view_mode == ViewMode.TRANSLATION + + widget.close()