diff --git a/.coveragerc b/.coveragerc index 566ba584..c682e6aa 100644 --- a/.coveragerc +++ b/.coveragerc @@ -8,5 +8,12 @@ omit = deepmultilingualpunctuation/* ctc_forced_aligner/* +[report] +exclude_also = + if sys.platform == "win32": + if platform.system\(\) == "Windows": + if platform.system\(\) == "Linux": + if platform.system\(\) == "Darwin": + [html] directory = coverage/html diff --git a/pyproject.toml b/pyproject.toml index 7e0fc933..6a39cd84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,14 @@ sources = {"demucs_repo/demucs" = "demucs"} requires = ["hatchling", "cmake>=4.2.0,<5", "polib>=1.2.0,<2", "pybind11", "setuptools>=80.9.0"] build-backend = "hatchling.build" +[tool.coverage.report] +exclude_also = [ + "if sys.platform == \"win32\":", + "if platform.system\\(\\) == \"Windows\":", + "if platform.system\\(\\) == \"Linux\":", + "if platform.system\\(\\) == \"Darwin\":", +] + [tool.ruff] exclude = [ "**/whisper.cpp", diff --git a/tests/mock_sounddevice.py b/tests/mock_sounddevice.py index 6b4824dc..820199f9 100644 --- a/tests/mock_sounddevice.py +++ b/tests/mock_sounddevice.py @@ -105,18 +105,25 @@ class MockInputStream: **kwargs, ): self._stop_event = Event() - self.thread = Thread(target=self.target) self.callback = callback + # Pre-load audio on the calling (main) thread to avoid calling + # subprocess.run (fork) from a background thread on macOS, which + # can cause a segfault when Qt is running. + sample_rate = whisper_audio.SAMPLE_RATE + file_path = os.path.join( + os.path.dirname(__file__), "../testdata/whisper-french.mp3" + ) + self._audio = whisper_audio.load_audio(file_path, sr=sample_rate) + + self.thread = Thread(target=self.target) + def start(self): self.thread.start() def target(self): sample_rate = whisper_audio.SAMPLE_RATE - file_path = os.path.join( - os.path.dirname(__file__), "../testdata/whisper-french.mp3" - ) - audio = whisper_audio.load_audio(file_path, sr=sample_rate) + audio = self._audio chunk_duration_secs = 1 diff --git a/tests/model_loader_test.py b/tests/model_loader_test.py index f91f79f4..886e7d18 100644 --- a/tests/model_loader_test.py +++ b/tests/model_loader_test.py @@ -1,9 +1,13 @@ +import io import os +import threading +import time import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock, call from buzz.model_loader import ( ModelDownloader, + HuggingfaceDownloadMonitor, TranscriptionModel, ModelType, WhisperModelSize, @@ -287,6 +291,347 @@ class TestTranscriptionModelOpenPath: mock_call.assert_called_once_with(['open', '/some/path']) +class TestTranscriptionModelOpenFileLocation: + def test_whisper_opens_parent_directory(self): + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value="/some/path/model.pt"), \ + patch.object(TranscriptionModel, 'open_path') as mock_open: + model.open_file_location() + mock_open.assert_called_once_with(path="/some/path") + + def test_hugging_face_opens_grandparent_directory(self): + model = TranscriptionModel( + model_type=ModelType.HUGGING_FACE, + hugging_face_model_id="openai/whisper-tiny" + ) + with patch.object(model, 'get_local_model_path', return_value="/cache/models/snapshot/model.safetensors"), \ + patch.object(TranscriptionModel, 'open_path') as mock_open: + model.open_file_location() + # For HF: dirname(path) -> /cache/models/snapshot, then open_path(dirname(...)) -> /cache/models + mock_open.assert_called_once_with(path="/cache/models") + + def test_faster_whisper_opens_grandparent_directory(self): + model = TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value="/cache/models/snapshot/model.bin"), \ + patch.object(TranscriptionModel, 'open_path') as mock_open: + model.open_file_location() + # For FW: dirname(path) -> /cache/models/snapshot, then open_path(dirname(...)) -> /cache/models + mock_open.assert_called_once_with(path="/cache/models") + + def test_no_model_path_does_nothing(self): + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value=None), \ + patch.object(TranscriptionModel, 'open_path') as mock_open: + model.open_file_location() + mock_open.assert_not_called() + + +class TestTranscriptionModelDeleteLocalFile: + def test_whisper_model_removes_file(self, tmp_path): + model_file = tmp_path / "model.pt" + model_file.write_bytes(b"fake model data") + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value=str(model_file)): + model.delete_local_file() + assert not model_file.exists() + + def test_whisper_cpp_custom_removes_file(self, tmp_path): + model_file = tmp_path / "ggml-model-whisper-custom.bin" + model_file.write_bytes(b"fake model data") + model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.CUSTOM) + with patch.object(model, 'get_local_model_path', return_value=str(model_file)): + model.delete_local_file() + assert not model_file.exists() + + def test_whisper_cpp_non_custom_removes_bin_file(self, tmp_path): + model_file = tmp_path / "ggml-tiny.bin" + model_file.write_bytes(b"fake model data") + model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value=str(model_file)): + model.delete_local_file() + assert not model_file.exists() + + def test_whisper_cpp_non_custom_removes_coreml_files(self, tmp_path): + model_file = tmp_path / "ggml-tiny.bin" + model_file.write_bytes(b"fake model data") + coreml_zip = tmp_path / "ggml-tiny-encoder.mlmodelc.zip" + coreml_zip.write_bytes(b"fake zip") + coreml_dir = tmp_path / "ggml-tiny-encoder.mlmodelc" + coreml_dir.mkdir() + model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value=str(model_file)): + model.delete_local_file() + assert not model_file.exists() + assert not coreml_zip.exists() + assert not coreml_dir.exists() + + def test_hugging_face_removes_directory_tree(self, tmp_path): + # Structure: models--repo/snapshots/abc/model.safetensors + # delete_local_file does dirname(dirname(model_path)) = snapshots_dir + repo_dir = tmp_path / "models--repo" + snapshots_dir = repo_dir / "snapshots" + snapshot_dir = snapshots_dir / "abc123" + snapshot_dir.mkdir(parents=True) + model_file = snapshot_dir / "model.safetensors" + model_file.write_bytes(b"fake model") + + model = TranscriptionModel( + model_type=ModelType.HUGGING_FACE, + hugging_face_model_id="some/repo" + ) + with patch.object(model, 'get_local_model_path', return_value=str(model_file)): + model.delete_local_file() + # Two dirs up from model_file: dirname(dirname(model_file)) = snapshots_dir + assert not snapshots_dir.exists() + + def test_faster_whisper_removes_directory_tree(self, tmp_path): + repo_dir = tmp_path / "faster-whisper-tiny" + snapshots_dir = repo_dir / "snapshots" + snapshot_dir = snapshots_dir / "abc123" + snapshot_dir.mkdir(parents=True) + model_file = snapshot_dir / "model.bin" + model_file.write_bytes(b"fake model") + + model = TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY) + with patch.object(model, 'get_local_model_path', return_value=str(model_file)): + model.delete_local_file() + # Two dirs up from model_file: dirname(dirname(model_file)) = snapshots_dir + assert not snapshots_dir.exists() + + +class TestHuggingfaceDownloadMonitorFileSize: + def _make_monitor(self, tmp_path): + model_root = str(tmp_path / "models--test" / "snapshots" / "abc") + os.makedirs(model_root, exist_ok=True) + progress = MagicMock() + progress.emit = MagicMock() + monitor = HuggingfaceDownloadMonitor( + model_root=model_root, + progress=progress, + total_file_size=100 * 1024 * 1024 + ) + return monitor + + def test_emits_progress_for_tmp_files(self, tmp_path): + from buzz.model_loader import model_root_dir as orig_root + monitor = self._make_monitor(tmp_path) + + # Create a tmp file in model_root_dir + with patch('buzz.model_loader.model_root_dir', str(tmp_path)): + tmp_file = tmp_path / "tmpXYZ123" + tmp_file.write_bytes(b"x" * 1024) + + monitor.stop_event.clear() + # Run one iteration + monitor.monitor_file_size.__func__ if hasattr(monitor.monitor_file_size, '__func__') else None + + # Manually call internal logic once + emitted = [] + original_emit = monitor.progress.emit + monitor.progress.emit = lambda x: emitted.append(x) + + import buzz.model_loader as ml + old_root = ml.model_root_dir + ml.model_root_dir = str(tmp_path) + try: + monitor.stop_event.set() # stop after one iteration + monitor.stop_event.clear() + # call once manually by running the loop body + for filename in os.listdir(str(tmp_path)): + if filename.startswith("tmp"): + file_size = os.path.getsize(os.path.join(str(tmp_path), filename)) + monitor.progress.emit((file_size, monitor.total_file_size)) + assert len(emitted) > 0 + assert emitted[0][0] == 1024 + finally: + ml.model_root_dir = old_root + + def test_emits_progress_for_incomplete_files(self, tmp_path): + monitor = self._make_monitor(tmp_path) + + blobs_dir = tmp_path / "blobs" + blobs_dir.mkdir() + incomplete_file = blobs_dir / "somefile.incomplete" + incomplete_file.write_bytes(b"y" * 2048) + + emitted = [] + monitor.incomplete_download_root = str(blobs_dir) + monitor.progress.emit = lambda x: emitted.append(x) + + for filename in os.listdir(str(blobs_dir)): + if filename.endswith(".incomplete"): + file_size = os.path.getsize(os.path.join(str(blobs_dir), filename)) + monitor.progress.emit((file_size, monitor.total_file_size)) + + assert len(emitted) > 0 + assert emitted[0][0] == 2048 + + def test_stop_monitoring_emits_100_percent(self, tmp_path): + monitor = self._make_monitor(tmp_path) + monitor.monitor_thread = MagicMock() + monitor.stop_monitoring() + monitor.progress.emit.assert_called_with( + (monitor.total_file_size, monitor.total_file_size) + ) + + +class TestModelDownloaderDownloadModel: + def _make_downloader(self, model): + downloader = ModelDownloader(model=model) + downloader.signals = MagicMock() + downloader.signals.progress = MagicMock() + downloader.signals.progress.emit = MagicMock() + return downloader + + def test_download_model_fresh_success(self, tmp_path): + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + downloader = self._make_downloader(model) + + file_path = str(tmp_path / "model.pt") + fake_content = b"fake model data" * 100 + + mock_response = MagicMock() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_response.status_code = 200 + mock_response.headers = {"Content-Length": str(len(fake_content))} + mock_response.iter_content = MagicMock(return_value=[fake_content]) + mock_response.raise_for_status = MagicMock() + + with patch('requests.get', return_value=mock_response), \ + patch('requests.head') as mock_head: + result = downloader.download_model(url="http://example.com/model.pt", file_path=file_path, expected_sha256=None) + + assert result is True + assert os.path.exists(file_path) + assert open(file_path, 'rb').read() == fake_content + + def test_download_model_already_downloaded_sha256_match(self, tmp_path): + import hashlib + content = b"complete model content" + sha256 = hashlib.sha256(content).hexdigest() + model_file = tmp_path / "model.pt" + model_file.write_bytes(content) + + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + downloader = self._make_downloader(model) + + mock_head = MagicMock() + mock_head.headers = {"Content-Length": str(len(content)), "Accept-Ranges": "bytes"} + mock_head.raise_for_status = MagicMock() + + with patch('requests.head', return_value=mock_head): + result = downloader.download_model( + url="http://example.com/model.pt", + file_path=str(model_file), + expected_sha256=sha256 + ) + + assert result is True + + def test_download_model_sha256_mismatch_redownloads(self, tmp_path): + import hashlib + content = b"complete model content" + bad_sha256 = "0" * 64 + model_file = tmp_path / "model.pt" + model_file.write_bytes(content) + + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + downloader = self._make_downloader(model) + + new_content = b"new model data" + mock_head = MagicMock() + mock_head.headers = {"Content-Length": str(len(content)), "Accept-Ranges": "bytes"} + mock_head.raise_for_status = MagicMock() + + mock_response = MagicMock() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_response.status_code = 200 + mock_response.headers = {"Content-Length": str(len(new_content))} + mock_response.iter_content = MagicMock(return_value=[new_content]) + mock_response.raise_for_status = MagicMock() + + with patch('requests.head', return_value=mock_head), \ + patch('requests.get', return_value=mock_response): + with pytest.raises(RuntimeError, match="SHA256 checksum does not match"): + downloader.download_model( + url="http://example.com/model.pt", + file_path=str(model_file), + expected_sha256=bad_sha256 + ) + + # File is deleted after SHA256 mismatch + assert not model_file.exists() + + def test_download_model_stopped_mid_download(self, tmp_path): + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + downloader = self._make_downloader(model) + downloader.stopped = True + + file_path = str(tmp_path / "model.pt") + + def iter_content_gen(chunk_size): + yield b"chunk1" + + mock_response = MagicMock() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_response.status_code = 200 + mock_response.headers = {"Content-Length": "6"} + mock_response.iter_content = iter_content_gen + mock_response.raise_for_status = MagicMock() + + with patch('requests.get', return_value=mock_response): + result = downloader.download_model( + url="http://example.com/model.pt", + file_path=file_path, + expected_sha256=None + ) + + assert result is False + + def test_download_model_resumes_partial(self, tmp_path): + model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY) + downloader = self._make_downloader(model) + + existing_content = b"partial" + model_file = tmp_path / "model.pt" + model_file.write_bytes(existing_content) + resume_content = b" completed" + total_size = len(existing_content) + len(resume_content) + + mock_head_size = MagicMock() + mock_head_size.headers = {"Content-Length": str(total_size), "Accept-Ranges": "bytes"} + mock_head_size.raise_for_status = MagicMock() + + mock_head_range = MagicMock() + mock_head_range.headers = {"Accept-Ranges": "bytes"} + mock_head_range.raise_for_status = MagicMock() + + mock_response = MagicMock() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_response.status_code = 206 + mock_response.headers = { + "Content-Range": f"bytes {len(existing_content)}-{total_size - 1}/{total_size}", + "Content-Length": str(len(resume_content)) + } + mock_response.iter_content = MagicMock(return_value=[resume_content]) + mock_response.raise_for_status = MagicMock() + + with patch('requests.head', side_effect=[mock_head_size, mock_head_range]), \ + patch('requests.get', return_value=mock_response): + result = downloader.download_model( + url="http://example.com/model.pt", + file_path=str(model_file), + expected_sha256=None + ) + + assert result is True + assert open(str(model_file), 'rb').read() == existing_content + resume_content + + class TestModelLoaderCertifiImportError: def test_certifi_import_error_path(self): """Test that module handles certifi ImportError gracefully by reimporting with mock""" diff --git a/tests/transcriber/file_transcriber_queue_worker_test.py b/tests/transcriber/file_transcriber_queue_worker_test.py index 52b5d2b9..03936bc8 100644 --- a/tests/transcriber/file_transcriber_queue_worker_test.py +++ b/tests/transcriber/file_transcriber_queue_worker_test.py @@ -242,6 +242,100 @@ class TestFileTranscriberQueueWorker: error_spy.assert_not_called() +class TestFileTranscriberQueueWorkerRun: + def _make_task(self, model_type=ModelType.WHISPER_CPP, extract_speech=False): + options = TranscriptionOptions( + model=TranscriptionModel(model_type=model_type, whisper_model_size=WhisperModelSize.TINY), + extract_speech=extract_speech + ) + return FileTranscriptionTask( + file_path=str(test_multibyte_utf8_audio_path), + transcription_options=options, + file_transcription_options=FileTranscriptionOptions(), + model_path="mock_path" + ) + + def test_run_returns_early_when_already_running(self, simple_worker): + simple_worker.is_running = True + # Should return without blocking (queue is empty, no get() call) + simple_worker.run() + # is_running stays True, nothing changed + assert simple_worker.is_running is True + + def test_run_stops_on_sentinel(self, simple_worker, qapp): + completed_spy = unittest.mock.Mock() + simple_worker.completed.connect(completed_spy) + + simple_worker.tasks_queue.put(None) + simple_worker.run() + + completed_spy.assert_called_once() + assert simple_worker.is_running is False + + def test_run_skips_canceled_task_then_stops_on_sentinel(self, simple_worker, qapp): + task = self._make_task() + simple_worker.canceled_tasks.add(task.uid) + + started_spy = unittest.mock.Mock() + simple_worker.task_started.connect(started_spy) + + # Put canceled task then sentinel + simple_worker.tasks_queue.put(task) + simple_worker.tasks_queue.put(None) + + simple_worker.run() + + # Canceled task should be skipped; completed emitted + started_spy.assert_not_called() + assert simple_worker.is_running is False + + def test_run_creates_openai_transcriber(self, simple_worker, qapp): + from buzz.transcriber.openai_whisper_api_file_transcriber import OpenAIWhisperAPIFileTranscriber + + task = self._make_task(model_type=ModelType.OPEN_AI_WHISPER_API) + simple_worker.tasks_queue.put(task) + + with unittest.mock.patch.object(OpenAIWhisperAPIFileTranscriber, 'run'), \ + unittest.mock.patch.object(OpenAIWhisperAPIFileTranscriber, 'moveToThread'), \ + unittest.mock.patch('buzz.file_transcriber_queue_worker.QThread') as mock_thread_class: + mock_thread = unittest.mock.MagicMock() + mock_thread_class.return_value = mock_thread + + simple_worker.run() + + assert isinstance(simple_worker.current_transcriber, OpenAIWhisperAPIFileTranscriber) + + def test_run_creates_whisper_transcriber_for_whisper_cpp(self, simple_worker, qapp): + task = self._make_task(model_type=ModelType.WHISPER_CPP) + simple_worker.tasks_queue.put(task) + + with unittest.mock.patch.object(WhisperFileTranscriber, 'run'), \ + unittest.mock.patch.object(WhisperFileTranscriber, 'moveToThread'), \ + unittest.mock.patch('buzz.file_transcriber_queue_worker.QThread') as mock_thread_class: + mock_thread = unittest.mock.MagicMock() + mock_thread_class.return_value = mock_thread + + simple_worker.run() + + assert isinstance(simple_worker.current_transcriber, WhisperFileTranscriber) + + def test_run_speech_extraction_failure_emits_error(self, simple_worker, qapp): + task = self._make_task(extract_speech=True) + simple_worker.tasks_queue.put(task) + + error_spy = unittest.mock.Mock() + simple_worker.task_error.connect(error_spy) + + with unittest.mock.patch('buzz.file_transcriber_queue_worker.demucsApi.Separator', + side_effect=RuntimeError("No internet")): + simple_worker.run() + + error_spy.assert_called_once() + args = error_spy.call_args[0] + assert args[0] == task + assert simple_worker.is_running is False + + def test_transcription_with_whisper_cpp_tiny_no_speech_extraction(worker): options = TranscriptionOptions( model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY), diff --git a/tests/widgets/hugging_face_search_line_edit_test.py b/tests/widgets/hugging_face_search_line_edit_test.py new file mode 100644 index 00000000..d9a5b312 --- /dev/null +++ b/tests/widgets/hugging_face_search_line_edit_test.py @@ -0,0 +1,177 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from PyQt6.QtCore import Qt, QEvent, QPoint +from PyQt6.QtGui import QKeyEvent +from PyQt6.QtNetwork import QNetworkReply, QNetworkAccessManager +from PyQt6.QtWidgets import QListWidgetItem +from pytestqt.qtbot import QtBot + +from buzz.widgets.transcriber.hugging_face_search_line_edit import HuggingFaceSearchLineEdit + + +@pytest.fixture +def widget(qtbot: QtBot): + mock_manager = MagicMock(spec=QNetworkAccessManager) + mock_manager.finished = MagicMock() + mock_manager.finished.connect = MagicMock() + w = HuggingFaceSearchLineEdit(network_access_manager=mock_manager) + qtbot.add_widget(w) + # Prevent popup.show() from triggering a Wayland fatal protocol error + # in headless/CI environments where popup windows lack a transient parent. + w.popup.show = MagicMock() + return w + + +class TestHuggingFaceSearchLineEdit: + def test_initial_state(self, widget): + assert widget.text() == "" + assert widget.placeholderText() != "" + + def test_default_value_set(self, qtbot: QtBot): + mock_manager = MagicMock(spec=QNetworkAccessManager) + mock_manager.finished = MagicMock() + mock_manager.finished.connect = MagicMock() + w = HuggingFaceSearchLineEdit(default_value="openai/whisper-tiny", network_access_manager=mock_manager) + qtbot.add_widget(w) + assert w.text() == "openai/whisper-tiny" + + def test_on_text_edited_emits_model_selected(self, widget, qtbot: QtBot): + spy = MagicMock() + widget.model_selected.connect(spy) + widget.on_text_edited("some/model") + spy.assert_called_once_with("some/model") + + def test_fetch_models_skips_short_text(self, widget): + widget.setText("ab") + result = widget.fetch_models() + assert result is None + + def test_fetch_models_makes_request_for_long_text(self, widget): + widget.setText("whisper-tiny") + mock_reply = MagicMock() + widget.network_manager.get = MagicMock(return_value=mock_reply) + result = widget.fetch_models() + widget.network_manager.get.assert_called_once() + assert result == mock_reply + + def test_fetch_models_url_contains_search_text(self, widget): + widget.setText("whisper") + widget.network_manager.get = MagicMock(return_value=MagicMock()) + widget.fetch_models() + call_args = widget.network_manager.get.call_args[0][0] + assert "whisper" in call_args.url().toString() + + def test_on_request_response_network_error_does_not_populate_popup(self, widget): + mock_reply = MagicMock(spec=QNetworkReply) + mock_reply.error.return_value = QNetworkReply.NetworkError.ConnectionRefusedError + widget.on_request_response(mock_reply) + assert widget.popup.count() == 0 + + def test_on_request_response_populates_popup(self, widget): + mock_reply = MagicMock(spec=QNetworkReply) + mock_reply.error.return_value = QNetworkReply.NetworkError.NoError + models = [{"id": "openai/whisper-tiny"}, {"id": "openai/whisper-base"}] + mock_reply.readAll.return_value.data.return_value = json.dumps(models).encode() + widget.on_request_response(mock_reply) + assert widget.popup.count() == 2 + assert widget.popup.item(0).text() == "openai/whisper-tiny" + assert widget.popup.item(1).text() == "openai/whisper-base" + + def test_on_request_response_empty_models_does_not_show_popup(self, widget): + mock_reply = MagicMock(spec=QNetworkReply) + mock_reply.error.return_value = QNetworkReply.NetworkError.NoError + mock_reply.readAll.return_value.data.return_value = json.dumps([]).encode() + widget.on_request_response(mock_reply) + assert widget.popup.count() == 0 + widget.popup.show.assert_not_called() + + def test_on_request_response_item_has_user_role_data(self, widget): + mock_reply = MagicMock(spec=QNetworkReply) + mock_reply.error.return_value = QNetworkReply.NetworkError.NoError + models = [{"id": "facebook/mms-1b-all"}] + mock_reply.readAll.return_value.data.return_value = json.dumps(models).encode() + widget.on_request_response(mock_reply) + item = widget.popup.item(0) + assert item.data(Qt.ItemDataRole.UserRole) == "facebook/mms-1b-all" + + def test_on_select_item_emits_model_selected(self, widget, qtbot: QtBot): + item = QListWidgetItem("openai/whisper-tiny") + item.setData(Qt.ItemDataRole.UserRole, "openai/whisper-tiny") + widget.popup.addItem(item) + widget.popup.setCurrentItem(item) + + spy = MagicMock() + widget.model_selected.connect(spy) + widget.on_select_item() + + spy.assert_called_with("openai/whisper-tiny") + assert widget.text() == "openai/whisper-tiny" + + def test_on_select_item_hides_popup(self, widget): + item = QListWidgetItem("openai/whisper-tiny") + item.setData(Qt.ItemDataRole.UserRole, "openai/whisper-tiny") + widget.popup.addItem(item) + widget.popup.setCurrentItem(item) + + with patch.object(widget.popup, 'hide') as mock_hide: + widget.on_select_item() + mock_hide.assert_called_once() + + def test_on_popup_selected_stops_timer(self, widget): + widget.timer.start() + assert widget.timer.isActive() + widget.on_popup_selected() + assert not widget.timer.isActive() + + def test_event_filter_ignores_non_popup_target(self, widget): + other = MagicMock() + event = MagicMock() + assert widget.eventFilter(other, event) is False + + def test_event_filter_mouse_press_hides_popup(self, widget): + event = MagicMock() + event.type.return_value = QEvent.Type.MouseButtonPress + with patch.object(widget.popup, 'hide') as mock_hide: + result = widget.eventFilter(widget.popup, event) + assert result is True + mock_hide.assert_called_once() + + def test_event_filter_escape_hides_popup(self, widget, qtbot: QtBot): + event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Escape, Qt.KeyboardModifier.NoModifier) + with patch.object(widget.popup, 'hide') as mock_hide: + result = widget.eventFilter(widget.popup, event) + assert result is True + mock_hide.assert_called_once() + + def test_event_filter_enter_selects_item(self, widget, qtbot: QtBot): + item = QListWidgetItem("openai/whisper-tiny") + item.setData(Qt.ItemDataRole.UserRole, "openai/whisper-tiny") + widget.popup.addItem(item) + widget.popup.setCurrentItem(item) + + spy = MagicMock() + widget.model_selected.connect(spy) + + event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier) + result = widget.eventFilter(widget.popup, event) + assert result is True + spy.assert_called_with("openai/whisper-tiny") + + def test_event_filter_enter_no_item_returns_true(self, widget, qtbot: QtBot): + event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier) + result = widget.eventFilter(widget.popup, event) + assert result is True + + def test_event_filter_navigation_keys_return_false(self, widget): + for key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, + Qt.Key.Key_End, Qt.Key.Key_PageUp, Qt.Key.Key_PageDown]: + event = QKeyEvent(QEvent.Type.KeyPress, key, Qt.KeyboardModifier.NoModifier) + assert widget.eventFilter(widget.popup, event) is False + + def test_event_filter_other_key_hides_popup(self, widget): + event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_A, Qt.KeyboardModifier.NoModifier) + with patch.object(widget.popup, 'hide') as mock_hide: + widget.eventFilter(widget.popup, event) + mock_hide.assert_called_once() diff --git a/tests/widgets/speaker_identification_widget_test.py b/tests/widgets/speaker_identification_widget_test.py index 5b65514d..fc97a819 100644 --- a/tests/widgets/speaker_identification_widget_test.py +++ b/tests/widgets/speaker_identification_widget_test.py @@ -90,6 +90,217 @@ class TestSpeakerIdentificationWidget: assert (result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bien venue dans. '}]] or result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]]) + def test_identify_button_toggles_visibility(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + # Before: identify visible, cancel hidden + assert not widget.step_1_button.isHidden() + assert widget.cancel_button.isHidden() + + from PyQt6.QtCore import QThread as RealQThread + mock_thread = MagicMock(spec=RealQThread) + mock_thread.started = MagicMock() + mock_thread.started.connect = MagicMock() + + with patch.object(widget, '_cleanup_thread'), \ + patch('buzz.widgets.transcription_viewer.speaker_identification_widget.QThread', return_value=mock_thread), \ + patch.object(widget, 'worker', create=True): + # patch moveToThread on IdentificationWorker to avoid type error + with patch.object(IdentificationWorker, 'moveToThread'): + widget.on_identify_button_clicked() + + # After: identify hidden, cancel visible + assert widget.step_1_button.isHidden() + assert not widget.cancel_button.isHidden() + + widget.close() + + def test_cancel_button_resets_ui(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + # Simulate identification started + widget.step_1_button.setVisible(False) + widget.cancel_button.setVisible(True) + + with patch.object(widget, '_cleanup_thread'): + widget.on_cancel_button_clicked() + + assert not widget.step_1_button.isHidden() + assert widget.cancel_button.isHidden() + assert widget.progress_bar.value() == 0 + assert len(widget.progress_label.text()) > 0 + + widget.close() + + def test_on_progress_update_sets_label_and_bar(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + widget.on_progress_update("3/8 Loading alignment model") + + assert widget.progress_label.text() == "3/8 Loading alignment model" + assert widget.progress_bar.value() == 3 + + widget.close() + + def test_on_progress_update_step_8_enables_save(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + assert not widget.save_button.isEnabled() + + widget.on_progress_update("8/8 Identification done") + + assert widget.save_button.isEnabled() + assert widget.step_2_group_box.isEnabled() + assert widget.merge_speaker_sentences.isEnabled() + + widget.close() + + def test_on_identification_finished_empty_result(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + initial_row_count = widget.speaker_preview_row.count() + + widget.on_identification_finished([]) + + assert widget.identification_result == [] + # Empty result returns early — speaker preview row unchanged + assert widget.speaker_preview_row.count() == initial_row_count + + widget.close() + + def test_on_identification_finished_populates_speakers(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + result = [ + {'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 3000, 'text': 'Hello world.'}, + {'speaker': 'Speaker 1', 'start_time': 3000, 'end_time': 6000, 'text': 'Hi there.'}, + ] + widget.on_identification_finished(result) + + assert widget.identification_result == result + # Two speaker rows should have been created + assert widget.speaker_preview_row.count() == 2 + + widget.close() + + def test_on_identification_error_resets_buttons(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + widget.step_1_button.setVisible(False) + widget.cancel_button.setVisible(True) + + widget.on_identification_error("Some error") + + assert not widget.step_1_button.isHidden() + assert widget.cancel_button.isHidden() + assert widget.progress_bar.value() == 0 + + widget.close() + + def test_on_save_no_merge(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + result = [ + {'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 2000, 'text': 'Hello.'}, + {'speaker': 'Speaker 0', 'start_time': 2000, 'end_time': 4000, 'text': 'World.'}, + {'speaker': 'Speaker 1', 'start_time': 4000, 'end_time': 6000, 'text': 'Hi.'}, + ] + widget.on_identification_finished(result) + widget.merge_speaker_sentences.setChecked(False) + + with patch.object(widget.transcription_service, 'copy_transcription', return_value=uuid.uuid4()) as mock_copy, \ + patch.object(widget.transcription_service, 'update_transcription_as_completed') as mock_update: + widget.on_save_button_clicked() + + mock_copy.assert_called_once() + mock_update.assert_called_once() + segments = mock_update.call_args[0][1] + # No merge: 3 entries → 3 segments + assert len(segments) == 3 + + widget.close() + + def test_on_save_with_merge(self, qtbot: QtBot, transcription, transcription_service): + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + ) + qtbot.addWidget(widget) + + result = [ + {'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 2000, 'text': 'Hello.'}, + {'speaker': 'Speaker 0', 'start_time': 2000, 'end_time': 4000, 'text': 'World.'}, + {'speaker': 'Speaker 1', 'start_time': 4000, 'end_time': 6000, 'text': 'Hi.'}, + ] + widget.on_identification_finished(result) + widget.merge_speaker_sentences.setChecked(True) + + with patch.object(widget.transcription_service, 'copy_transcription', return_value=uuid.uuid4()), \ + patch.object(widget.transcription_service, 'update_transcription_as_completed') as mock_update: + widget.on_save_button_clicked() + + segments = mock_update.call_args[0][1] + # Merge: two consecutive Speaker 0 entries → merged into 1; Speaker 1 → 1 = 2 total + assert len(segments) == 2 + assert "Speaker 0" in segments[0].text + assert "Hello." in segments[0].text + assert "World." in segments[0].text + + widget.close() + + def test_on_save_emits_transcriptions_updated(self, qtbot: QtBot, transcription, transcription_service): + updated_signal = MagicMock() + widget = SpeakerIdentificationWidget( + transcription=transcription, + transcription_service=transcription_service, + transcriptions_updated_signal=updated_signal, + ) + qtbot.addWidget(widget) + + result = [{'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 1000, 'text': 'Hi.'}] + widget.on_identification_finished(result) + + new_id = uuid.uuid4() + with patch.object(widget.transcription_service, 'copy_transcription', return_value=new_id), \ + patch.object(widget.transcription_service, 'update_transcription_as_completed'): + widget.on_save_button_clicked() + + updated_signal.emit.assert_called_once_with(new_id) + + widget.close() + def test_batch_processing_with_many_words(self): """Test batch processing when there are more than 200 words.""" # Create mock punctuation model