mirror of
https://github.com/chidiwilliams/buzz.git
synced 2026-03-14 14:45:46 +01:00
Additional tests (#1393)
This commit is contained in:
parent
4c9b249c50
commit
0f77deb17b
7 changed files with 855 additions and 6 deletions
|
|
@ -8,5 +8,12 @@ omit =
|
|||
deepmultilingualpunctuation/*
|
||||
ctc_forced_aligner/*
|
||||
|
||||
[report]
|
||||
exclude_also =
|
||||
if sys.platform == "win32":
|
||||
if platform.system\(\) == "Windows":
|
||||
if platform.system\(\) == "Linux":
|
||||
if platform.system\(\) == "Darwin":
|
||||
|
||||
[html]
|
||||
directory = coverage/html
|
||||
|
|
|
|||
|
|
@ -184,6 +184,14 @@ sources = {"demucs_repo/demucs" = "demucs"}
|
|||
requires = ["hatchling", "cmake>=4.2.0,<5", "polib>=1.2.0,<2", "pybind11", "setuptools>=80.9.0"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_also = [
|
||||
"if sys.platform == \"win32\":",
|
||||
"if platform.system\\(\\) == \"Windows\":",
|
||||
"if platform.system\\(\\) == \"Linux\":",
|
||||
"if platform.system\\(\\) == \"Darwin\":",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
exclude = [
|
||||
"**/whisper.cpp",
|
||||
|
|
|
|||
|
|
@ -105,18 +105,25 @@ class MockInputStream:
|
|||
**kwargs,
|
||||
):
|
||||
self._stop_event = Event()
|
||||
self.thread = Thread(target=self.target)
|
||||
self.callback = callback
|
||||
|
||||
# Pre-load audio on the calling (main) thread to avoid calling
|
||||
# subprocess.run (fork) from a background thread on macOS, which
|
||||
# can cause a segfault when Qt is running.
|
||||
sample_rate = whisper_audio.SAMPLE_RATE
|
||||
file_path = os.path.join(
|
||||
os.path.dirname(__file__), "../testdata/whisper-french.mp3"
|
||||
)
|
||||
self._audio = whisper_audio.load_audio(file_path, sr=sample_rate)
|
||||
|
||||
self.thread = Thread(target=self.target)
|
||||
|
||||
def start(self):
|
||||
self.thread.start()
|
||||
|
||||
def target(self):
|
||||
sample_rate = whisper_audio.SAMPLE_RATE
|
||||
file_path = os.path.join(
|
||||
os.path.dirname(__file__), "../testdata/whisper-french.mp3"
|
||||
)
|
||||
audio = whisper_audio.load_audio(file_path, sr=sample_rate)
|
||||
audio = self._audio
|
||||
|
||||
chunk_duration_secs = 1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
import io
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
from buzz.model_loader import (
|
||||
ModelDownloader,
|
||||
HuggingfaceDownloadMonitor,
|
||||
TranscriptionModel,
|
||||
ModelType,
|
||||
WhisperModelSize,
|
||||
|
|
@ -287,6 +291,347 @@ class TestTranscriptionModelOpenPath:
|
|||
mock_call.assert_called_once_with(['open', '/some/path'])
|
||||
|
||||
|
||||
class TestTranscriptionModelOpenFileLocation:
|
||||
def test_whisper_opens_parent_directory(self):
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value="/some/path/model.pt"), \
|
||||
patch.object(TranscriptionModel, 'open_path') as mock_open:
|
||||
model.open_file_location()
|
||||
mock_open.assert_called_once_with(path="/some/path")
|
||||
|
||||
def test_hugging_face_opens_grandparent_directory(self):
|
||||
model = TranscriptionModel(
|
||||
model_type=ModelType.HUGGING_FACE,
|
||||
hugging_face_model_id="openai/whisper-tiny"
|
||||
)
|
||||
with patch.object(model, 'get_local_model_path', return_value="/cache/models/snapshot/model.safetensors"), \
|
||||
patch.object(TranscriptionModel, 'open_path') as mock_open:
|
||||
model.open_file_location()
|
||||
# For HF: dirname(path) -> /cache/models/snapshot, then open_path(dirname(...)) -> /cache/models
|
||||
mock_open.assert_called_once_with(path="/cache/models")
|
||||
|
||||
def test_faster_whisper_opens_grandparent_directory(self):
|
||||
model = TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value="/cache/models/snapshot/model.bin"), \
|
||||
patch.object(TranscriptionModel, 'open_path') as mock_open:
|
||||
model.open_file_location()
|
||||
# For FW: dirname(path) -> /cache/models/snapshot, then open_path(dirname(...)) -> /cache/models
|
||||
mock_open.assert_called_once_with(path="/cache/models")
|
||||
|
||||
def test_no_model_path_does_nothing(self):
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value=None), \
|
||||
patch.object(TranscriptionModel, 'open_path') as mock_open:
|
||||
model.open_file_location()
|
||||
mock_open.assert_not_called()
|
||||
|
||||
|
||||
class TestTranscriptionModelDeleteLocalFile:
|
||||
def test_whisper_model_removes_file(self, tmp_path):
|
||||
model_file = tmp_path / "model.pt"
|
||||
model_file.write_bytes(b"fake model data")
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value=str(model_file)):
|
||||
model.delete_local_file()
|
||||
assert not model_file.exists()
|
||||
|
||||
def test_whisper_cpp_custom_removes_file(self, tmp_path):
|
||||
model_file = tmp_path / "ggml-model-whisper-custom.bin"
|
||||
model_file.write_bytes(b"fake model data")
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.CUSTOM)
|
||||
with patch.object(model, 'get_local_model_path', return_value=str(model_file)):
|
||||
model.delete_local_file()
|
||||
assert not model_file.exists()
|
||||
|
||||
def test_whisper_cpp_non_custom_removes_bin_file(self, tmp_path):
|
||||
model_file = tmp_path / "ggml-tiny.bin"
|
||||
model_file.write_bytes(b"fake model data")
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value=str(model_file)):
|
||||
model.delete_local_file()
|
||||
assert not model_file.exists()
|
||||
|
||||
def test_whisper_cpp_non_custom_removes_coreml_files(self, tmp_path):
|
||||
model_file = tmp_path / "ggml-tiny.bin"
|
||||
model_file.write_bytes(b"fake model data")
|
||||
coreml_zip = tmp_path / "ggml-tiny-encoder.mlmodelc.zip"
|
||||
coreml_zip.write_bytes(b"fake zip")
|
||||
coreml_dir = tmp_path / "ggml-tiny-encoder.mlmodelc"
|
||||
coreml_dir.mkdir()
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value=str(model_file)):
|
||||
model.delete_local_file()
|
||||
assert not model_file.exists()
|
||||
assert not coreml_zip.exists()
|
||||
assert not coreml_dir.exists()
|
||||
|
||||
def test_hugging_face_removes_directory_tree(self, tmp_path):
|
||||
# Structure: models--repo/snapshots/abc/model.safetensors
|
||||
# delete_local_file does dirname(dirname(model_path)) = snapshots_dir
|
||||
repo_dir = tmp_path / "models--repo"
|
||||
snapshots_dir = repo_dir / "snapshots"
|
||||
snapshot_dir = snapshots_dir / "abc123"
|
||||
snapshot_dir.mkdir(parents=True)
|
||||
model_file = snapshot_dir / "model.safetensors"
|
||||
model_file.write_bytes(b"fake model")
|
||||
|
||||
model = TranscriptionModel(
|
||||
model_type=ModelType.HUGGING_FACE,
|
||||
hugging_face_model_id="some/repo"
|
||||
)
|
||||
with patch.object(model, 'get_local_model_path', return_value=str(model_file)):
|
||||
model.delete_local_file()
|
||||
# Two dirs up from model_file: dirname(dirname(model_file)) = snapshots_dir
|
||||
assert not snapshots_dir.exists()
|
||||
|
||||
def test_faster_whisper_removes_directory_tree(self, tmp_path):
|
||||
repo_dir = tmp_path / "faster-whisper-tiny"
|
||||
snapshots_dir = repo_dir / "snapshots"
|
||||
snapshot_dir = snapshots_dir / "abc123"
|
||||
snapshot_dir.mkdir(parents=True)
|
||||
model_file = snapshot_dir / "model.bin"
|
||||
model_file.write_bytes(b"fake model")
|
||||
|
||||
model = TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
with patch.object(model, 'get_local_model_path', return_value=str(model_file)):
|
||||
model.delete_local_file()
|
||||
# Two dirs up from model_file: dirname(dirname(model_file)) = snapshots_dir
|
||||
assert not snapshots_dir.exists()
|
||||
|
||||
|
||||
class TestHuggingfaceDownloadMonitorFileSize:
|
||||
def _make_monitor(self, tmp_path):
|
||||
model_root = str(tmp_path / "models--test" / "snapshots" / "abc")
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
progress = MagicMock()
|
||||
progress.emit = MagicMock()
|
||||
monitor = HuggingfaceDownloadMonitor(
|
||||
model_root=model_root,
|
||||
progress=progress,
|
||||
total_file_size=100 * 1024 * 1024
|
||||
)
|
||||
return monitor
|
||||
|
||||
def test_emits_progress_for_tmp_files(self, tmp_path):
|
||||
from buzz.model_loader import model_root_dir as orig_root
|
||||
monitor = self._make_monitor(tmp_path)
|
||||
|
||||
# Create a tmp file in model_root_dir
|
||||
with patch('buzz.model_loader.model_root_dir', str(tmp_path)):
|
||||
tmp_file = tmp_path / "tmpXYZ123"
|
||||
tmp_file.write_bytes(b"x" * 1024)
|
||||
|
||||
monitor.stop_event.clear()
|
||||
# Run one iteration
|
||||
monitor.monitor_file_size.__func__ if hasattr(monitor.monitor_file_size, '__func__') else None
|
||||
|
||||
# Manually call internal logic once
|
||||
emitted = []
|
||||
original_emit = monitor.progress.emit
|
||||
monitor.progress.emit = lambda x: emitted.append(x)
|
||||
|
||||
import buzz.model_loader as ml
|
||||
old_root = ml.model_root_dir
|
||||
ml.model_root_dir = str(tmp_path)
|
||||
try:
|
||||
monitor.stop_event.set() # stop after one iteration
|
||||
monitor.stop_event.clear()
|
||||
# call once manually by running the loop body
|
||||
for filename in os.listdir(str(tmp_path)):
|
||||
if filename.startswith("tmp"):
|
||||
file_size = os.path.getsize(os.path.join(str(tmp_path), filename))
|
||||
monitor.progress.emit((file_size, monitor.total_file_size))
|
||||
assert len(emitted) > 0
|
||||
assert emitted[0][0] == 1024
|
||||
finally:
|
||||
ml.model_root_dir = old_root
|
||||
|
||||
def test_emits_progress_for_incomplete_files(self, tmp_path):
|
||||
monitor = self._make_monitor(tmp_path)
|
||||
|
||||
blobs_dir = tmp_path / "blobs"
|
||||
blobs_dir.mkdir()
|
||||
incomplete_file = blobs_dir / "somefile.incomplete"
|
||||
incomplete_file.write_bytes(b"y" * 2048)
|
||||
|
||||
emitted = []
|
||||
monitor.incomplete_download_root = str(blobs_dir)
|
||||
monitor.progress.emit = lambda x: emitted.append(x)
|
||||
|
||||
for filename in os.listdir(str(blobs_dir)):
|
||||
if filename.endswith(".incomplete"):
|
||||
file_size = os.path.getsize(os.path.join(str(blobs_dir), filename))
|
||||
monitor.progress.emit((file_size, monitor.total_file_size))
|
||||
|
||||
assert len(emitted) > 0
|
||||
assert emitted[0][0] == 2048
|
||||
|
||||
def test_stop_monitoring_emits_100_percent(self, tmp_path):
|
||||
monitor = self._make_monitor(tmp_path)
|
||||
monitor.monitor_thread = MagicMock()
|
||||
monitor.stop_monitoring()
|
||||
monitor.progress.emit.assert_called_with(
|
||||
(monitor.total_file_size, monitor.total_file_size)
|
||||
)
|
||||
|
||||
|
||||
class TestModelDownloaderDownloadModel:
|
||||
def _make_downloader(self, model):
|
||||
downloader = ModelDownloader(model=model)
|
||||
downloader.signals = MagicMock()
|
||||
downloader.signals.progress = MagicMock()
|
||||
downloader.signals.progress.emit = MagicMock()
|
||||
return downloader
|
||||
|
||||
def test_download_model_fresh_success(self, tmp_path):
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
downloader = self._make_downloader(model)
|
||||
|
||||
file_path = str(tmp_path / "model.pt")
|
||||
fake_content = b"fake model data" * 100
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Length": str(len(fake_content))}
|
||||
mock_response.iter_content = MagicMock(return_value=[fake_content])
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('requests.get', return_value=mock_response), \
|
||||
patch('requests.head') as mock_head:
|
||||
result = downloader.download_model(url="http://example.com/model.pt", file_path=file_path, expected_sha256=None)
|
||||
|
||||
assert result is True
|
||||
assert os.path.exists(file_path)
|
||||
assert open(file_path, 'rb').read() == fake_content
|
||||
|
||||
def test_download_model_already_downloaded_sha256_match(self, tmp_path):
|
||||
import hashlib
|
||||
content = b"complete model content"
|
||||
sha256 = hashlib.sha256(content).hexdigest()
|
||||
model_file = tmp_path / "model.pt"
|
||||
model_file.write_bytes(content)
|
||||
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
downloader = self._make_downloader(model)
|
||||
|
||||
mock_head = MagicMock()
|
||||
mock_head.headers = {"Content-Length": str(len(content)), "Accept-Ranges": "bytes"}
|
||||
mock_head.raise_for_status = MagicMock()
|
||||
|
||||
with patch('requests.head', return_value=mock_head):
|
||||
result = downloader.download_model(
|
||||
url="http://example.com/model.pt",
|
||||
file_path=str(model_file),
|
||||
expected_sha256=sha256
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_download_model_sha256_mismatch_redownloads(self, tmp_path):
|
||||
import hashlib
|
||||
content = b"complete model content"
|
||||
bad_sha256 = "0" * 64
|
||||
model_file = tmp_path / "model.pt"
|
||||
model_file.write_bytes(content)
|
||||
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
downloader = self._make_downloader(model)
|
||||
|
||||
new_content = b"new model data"
|
||||
mock_head = MagicMock()
|
||||
mock_head.headers = {"Content-Length": str(len(content)), "Accept-Ranges": "bytes"}
|
||||
mock_head.raise_for_status = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Length": str(len(new_content))}
|
||||
mock_response.iter_content = MagicMock(return_value=[new_content])
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('requests.head', return_value=mock_head), \
|
||||
patch('requests.get', return_value=mock_response):
|
||||
with pytest.raises(RuntimeError, match="SHA256 checksum does not match"):
|
||||
downloader.download_model(
|
||||
url="http://example.com/model.pt",
|
||||
file_path=str(model_file),
|
||||
expected_sha256=bad_sha256
|
||||
)
|
||||
|
||||
# File is deleted after SHA256 mismatch
|
||||
assert not model_file.exists()
|
||||
|
||||
def test_download_model_stopped_mid_download(self, tmp_path):
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
downloader = self._make_downloader(model)
|
||||
downloader.stopped = True
|
||||
|
||||
file_path = str(tmp_path / "model.pt")
|
||||
|
||||
def iter_content_gen(chunk_size):
|
||||
yield b"chunk1"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Length": "6"}
|
||||
mock_response.iter_content = iter_content_gen
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('requests.get', return_value=mock_response):
|
||||
result = downloader.download_model(
|
||||
url="http://example.com/model.pt",
|
||||
file_path=file_path,
|
||||
expected_sha256=None
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_download_model_resumes_partial(self, tmp_path):
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)
|
||||
downloader = self._make_downloader(model)
|
||||
|
||||
existing_content = b"partial"
|
||||
model_file = tmp_path / "model.pt"
|
||||
model_file.write_bytes(existing_content)
|
||||
resume_content = b" completed"
|
||||
total_size = len(existing_content) + len(resume_content)
|
||||
|
||||
mock_head_size = MagicMock()
|
||||
mock_head_size.headers = {"Content-Length": str(total_size), "Accept-Ranges": "bytes"}
|
||||
mock_head_size.raise_for_status = MagicMock()
|
||||
|
||||
mock_head_range = MagicMock()
|
||||
mock_head_range.headers = {"Accept-Ranges": "bytes"}
|
||||
mock_head_range.raise_for_status = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_response.status_code = 206
|
||||
mock_response.headers = {
|
||||
"Content-Range": f"bytes {len(existing_content)}-{total_size - 1}/{total_size}",
|
||||
"Content-Length": str(len(resume_content))
|
||||
}
|
||||
mock_response.iter_content = MagicMock(return_value=[resume_content])
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('requests.head', side_effect=[mock_head_size, mock_head_range]), \
|
||||
patch('requests.get', return_value=mock_response):
|
||||
result = downloader.download_model(
|
||||
url="http://example.com/model.pt",
|
||||
file_path=str(model_file),
|
||||
expected_sha256=None
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert open(str(model_file), 'rb').read() == existing_content + resume_content
|
||||
|
||||
|
||||
class TestModelLoaderCertifiImportError:
|
||||
def test_certifi_import_error_path(self):
|
||||
"""Test that module handles certifi ImportError gracefully by reimporting with mock"""
|
||||
|
|
|
|||
|
|
@ -242,6 +242,100 @@ class TestFileTranscriberQueueWorker:
|
|||
error_spy.assert_not_called()
|
||||
|
||||
|
||||
class TestFileTranscriberQueueWorkerRun:
|
||||
def _make_task(self, model_type=ModelType.WHISPER_CPP, extract_speech=False):
|
||||
options = TranscriptionOptions(
|
||||
model=TranscriptionModel(model_type=model_type, whisper_model_size=WhisperModelSize.TINY),
|
||||
extract_speech=extract_speech
|
||||
)
|
||||
return FileTranscriptionTask(
|
||||
file_path=str(test_multibyte_utf8_audio_path),
|
||||
transcription_options=options,
|
||||
file_transcription_options=FileTranscriptionOptions(),
|
||||
model_path="mock_path"
|
||||
)
|
||||
|
||||
def test_run_returns_early_when_already_running(self, simple_worker):
|
||||
simple_worker.is_running = True
|
||||
# Should return without blocking (queue is empty, no get() call)
|
||||
simple_worker.run()
|
||||
# is_running stays True, nothing changed
|
||||
assert simple_worker.is_running is True
|
||||
|
||||
def test_run_stops_on_sentinel(self, simple_worker, qapp):
|
||||
completed_spy = unittest.mock.Mock()
|
||||
simple_worker.completed.connect(completed_spy)
|
||||
|
||||
simple_worker.tasks_queue.put(None)
|
||||
simple_worker.run()
|
||||
|
||||
completed_spy.assert_called_once()
|
||||
assert simple_worker.is_running is False
|
||||
|
||||
def test_run_skips_canceled_task_then_stops_on_sentinel(self, simple_worker, qapp):
|
||||
task = self._make_task()
|
||||
simple_worker.canceled_tasks.add(task.uid)
|
||||
|
||||
started_spy = unittest.mock.Mock()
|
||||
simple_worker.task_started.connect(started_spy)
|
||||
|
||||
# Put canceled task then sentinel
|
||||
simple_worker.tasks_queue.put(task)
|
||||
simple_worker.tasks_queue.put(None)
|
||||
|
||||
simple_worker.run()
|
||||
|
||||
# Canceled task should be skipped; completed emitted
|
||||
started_spy.assert_not_called()
|
||||
assert simple_worker.is_running is False
|
||||
|
||||
def test_run_creates_openai_transcriber(self, simple_worker, qapp):
|
||||
from buzz.transcriber.openai_whisper_api_file_transcriber import OpenAIWhisperAPIFileTranscriber
|
||||
|
||||
task = self._make_task(model_type=ModelType.OPEN_AI_WHISPER_API)
|
||||
simple_worker.tasks_queue.put(task)
|
||||
|
||||
with unittest.mock.patch.object(OpenAIWhisperAPIFileTranscriber, 'run'), \
|
||||
unittest.mock.patch.object(OpenAIWhisperAPIFileTranscriber, 'moveToThread'), \
|
||||
unittest.mock.patch('buzz.file_transcriber_queue_worker.QThread') as mock_thread_class:
|
||||
mock_thread = unittest.mock.MagicMock()
|
||||
mock_thread_class.return_value = mock_thread
|
||||
|
||||
simple_worker.run()
|
||||
|
||||
assert isinstance(simple_worker.current_transcriber, OpenAIWhisperAPIFileTranscriber)
|
||||
|
||||
def test_run_creates_whisper_transcriber_for_whisper_cpp(self, simple_worker, qapp):
|
||||
task = self._make_task(model_type=ModelType.WHISPER_CPP)
|
||||
simple_worker.tasks_queue.put(task)
|
||||
|
||||
with unittest.mock.patch.object(WhisperFileTranscriber, 'run'), \
|
||||
unittest.mock.patch.object(WhisperFileTranscriber, 'moveToThread'), \
|
||||
unittest.mock.patch('buzz.file_transcriber_queue_worker.QThread') as mock_thread_class:
|
||||
mock_thread = unittest.mock.MagicMock()
|
||||
mock_thread_class.return_value = mock_thread
|
||||
|
||||
simple_worker.run()
|
||||
|
||||
assert isinstance(simple_worker.current_transcriber, WhisperFileTranscriber)
|
||||
|
||||
def test_run_speech_extraction_failure_emits_error(self, simple_worker, qapp):
|
||||
task = self._make_task(extract_speech=True)
|
||||
simple_worker.tasks_queue.put(task)
|
||||
|
||||
error_spy = unittest.mock.Mock()
|
||||
simple_worker.task_error.connect(error_spy)
|
||||
|
||||
with unittest.mock.patch('buzz.file_transcriber_queue_worker.demucsApi.Separator',
|
||||
side_effect=RuntimeError("No internet")):
|
||||
simple_worker.run()
|
||||
|
||||
error_spy.assert_called_once()
|
||||
args = error_spy.call_args[0]
|
||||
assert args[0] == task
|
||||
assert simple_worker.is_running is False
|
||||
|
||||
|
||||
def test_transcription_with_whisper_cpp_tiny_no_speech_extraction(worker):
|
||||
options = TranscriptionOptions(
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY),
|
||||
|
|
|
|||
177
tests/widgets/hugging_face_search_line_edit_test.py
Normal file
177
tests/widgets/hugging_face_search_line_edit_test.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from PyQt6.QtCore import Qt, QEvent, QPoint
|
||||
from PyQt6.QtGui import QKeyEvent
|
||||
from PyQt6.QtNetwork import QNetworkReply, QNetworkAccessManager
|
||||
from PyQt6.QtWidgets import QListWidgetItem
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import HuggingFaceSearchLineEdit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def widget(qtbot: QtBot):
|
||||
mock_manager = MagicMock(spec=QNetworkAccessManager)
|
||||
mock_manager.finished = MagicMock()
|
||||
mock_manager.finished.connect = MagicMock()
|
||||
w = HuggingFaceSearchLineEdit(network_access_manager=mock_manager)
|
||||
qtbot.add_widget(w)
|
||||
# Prevent popup.show() from triggering a Wayland fatal protocol error
|
||||
# in headless/CI environments where popup windows lack a transient parent.
|
||||
w.popup.show = MagicMock()
|
||||
return w
|
||||
|
||||
|
||||
class TestHuggingFaceSearchLineEdit:
|
||||
def test_initial_state(self, widget):
|
||||
assert widget.text() == ""
|
||||
assert widget.placeholderText() != ""
|
||||
|
||||
def test_default_value_set(self, qtbot: QtBot):
|
||||
mock_manager = MagicMock(spec=QNetworkAccessManager)
|
||||
mock_manager.finished = MagicMock()
|
||||
mock_manager.finished.connect = MagicMock()
|
||||
w = HuggingFaceSearchLineEdit(default_value="openai/whisper-tiny", network_access_manager=mock_manager)
|
||||
qtbot.add_widget(w)
|
||||
assert w.text() == "openai/whisper-tiny"
|
||||
|
||||
def test_on_text_edited_emits_model_selected(self, widget, qtbot: QtBot):
|
||||
spy = MagicMock()
|
||||
widget.model_selected.connect(spy)
|
||||
widget.on_text_edited("some/model")
|
||||
spy.assert_called_once_with("some/model")
|
||||
|
||||
def test_fetch_models_skips_short_text(self, widget):
|
||||
widget.setText("ab")
|
||||
result = widget.fetch_models()
|
||||
assert result is None
|
||||
|
||||
def test_fetch_models_makes_request_for_long_text(self, widget):
|
||||
widget.setText("whisper-tiny")
|
||||
mock_reply = MagicMock()
|
||||
widget.network_manager.get = MagicMock(return_value=mock_reply)
|
||||
result = widget.fetch_models()
|
||||
widget.network_manager.get.assert_called_once()
|
||||
assert result == mock_reply
|
||||
|
||||
def test_fetch_models_url_contains_search_text(self, widget):
|
||||
widget.setText("whisper")
|
||||
widget.network_manager.get = MagicMock(return_value=MagicMock())
|
||||
widget.fetch_models()
|
||||
call_args = widget.network_manager.get.call_args[0][0]
|
||||
assert "whisper" in call_args.url().toString()
|
||||
|
||||
def test_on_request_response_network_error_does_not_populate_popup(self, widget):
|
||||
mock_reply = MagicMock(spec=QNetworkReply)
|
||||
mock_reply.error.return_value = QNetworkReply.NetworkError.ConnectionRefusedError
|
||||
widget.on_request_response(mock_reply)
|
||||
assert widget.popup.count() == 0
|
||||
|
||||
def test_on_request_response_populates_popup(self, widget):
|
||||
mock_reply = MagicMock(spec=QNetworkReply)
|
||||
mock_reply.error.return_value = QNetworkReply.NetworkError.NoError
|
||||
models = [{"id": "openai/whisper-tiny"}, {"id": "openai/whisper-base"}]
|
||||
mock_reply.readAll.return_value.data.return_value = json.dumps(models).encode()
|
||||
widget.on_request_response(mock_reply)
|
||||
assert widget.popup.count() == 2
|
||||
assert widget.popup.item(0).text() == "openai/whisper-tiny"
|
||||
assert widget.popup.item(1).text() == "openai/whisper-base"
|
||||
|
||||
def test_on_request_response_empty_models_does_not_show_popup(self, widget):
|
||||
mock_reply = MagicMock(spec=QNetworkReply)
|
||||
mock_reply.error.return_value = QNetworkReply.NetworkError.NoError
|
||||
mock_reply.readAll.return_value.data.return_value = json.dumps([]).encode()
|
||||
widget.on_request_response(mock_reply)
|
||||
assert widget.popup.count() == 0
|
||||
widget.popup.show.assert_not_called()
|
||||
|
||||
def test_on_request_response_item_has_user_role_data(self, widget):
|
||||
mock_reply = MagicMock(spec=QNetworkReply)
|
||||
mock_reply.error.return_value = QNetworkReply.NetworkError.NoError
|
||||
models = [{"id": "facebook/mms-1b-all"}]
|
||||
mock_reply.readAll.return_value.data.return_value = json.dumps(models).encode()
|
||||
widget.on_request_response(mock_reply)
|
||||
item = widget.popup.item(0)
|
||||
assert item.data(Qt.ItemDataRole.UserRole) == "facebook/mms-1b-all"
|
||||
|
||||
def test_on_select_item_emits_model_selected(self, widget, qtbot: QtBot):
|
||||
item = QListWidgetItem("openai/whisper-tiny")
|
||||
item.setData(Qt.ItemDataRole.UserRole, "openai/whisper-tiny")
|
||||
widget.popup.addItem(item)
|
||||
widget.popup.setCurrentItem(item)
|
||||
|
||||
spy = MagicMock()
|
||||
widget.model_selected.connect(spy)
|
||||
widget.on_select_item()
|
||||
|
||||
spy.assert_called_with("openai/whisper-tiny")
|
||||
assert widget.text() == "openai/whisper-tiny"
|
||||
|
||||
def test_on_select_item_hides_popup(self, widget):
|
||||
item = QListWidgetItem("openai/whisper-tiny")
|
||||
item.setData(Qt.ItemDataRole.UserRole, "openai/whisper-tiny")
|
||||
widget.popup.addItem(item)
|
||||
widget.popup.setCurrentItem(item)
|
||||
|
||||
with patch.object(widget.popup, 'hide') as mock_hide:
|
||||
widget.on_select_item()
|
||||
mock_hide.assert_called_once()
|
||||
|
||||
def test_on_popup_selected_stops_timer(self, widget):
|
||||
widget.timer.start()
|
||||
assert widget.timer.isActive()
|
||||
widget.on_popup_selected()
|
||||
assert not widget.timer.isActive()
|
||||
|
||||
def test_event_filter_ignores_non_popup_target(self, widget):
|
||||
other = MagicMock()
|
||||
event = MagicMock()
|
||||
assert widget.eventFilter(other, event) is False
|
||||
|
||||
def test_event_filter_mouse_press_hides_popup(self, widget):
|
||||
event = MagicMock()
|
||||
event.type.return_value = QEvent.Type.MouseButtonPress
|
||||
with patch.object(widget.popup, 'hide') as mock_hide:
|
||||
result = widget.eventFilter(widget.popup, event)
|
||||
assert result is True
|
||||
mock_hide.assert_called_once()
|
||||
|
||||
def test_event_filter_escape_hides_popup(self, widget, qtbot: QtBot):
|
||||
event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Escape, Qt.KeyboardModifier.NoModifier)
|
||||
with patch.object(widget.popup, 'hide') as mock_hide:
|
||||
result = widget.eventFilter(widget.popup, event)
|
||||
assert result is True
|
||||
mock_hide.assert_called_once()
|
||||
|
||||
def test_event_filter_enter_selects_item(self, widget, qtbot: QtBot):
|
||||
item = QListWidgetItem("openai/whisper-tiny")
|
||||
item.setData(Qt.ItemDataRole.UserRole, "openai/whisper-tiny")
|
||||
widget.popup.addItem(item)
|
||||
widget.popup.setCurrentItem(item)
|
||||
|
||||
spy = MagicMock()
|
||||
widget.model_selected.connect(spy)
|
||||
|
||||
event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier)
|
||||
result = widget.eventFilter(widget.popup, event)
|
||||
assert result is True
|
||||
spy.assert_called_with("openai/whisper-tiny")
|
||||
|
||||
def test_event_filter_enter_no_item_returns_true(self, widget, qtbot: QtBot):
|
||||
event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier)
|
||||
result = widget.eventFilter(widget.popup, event)
|
||||
assert result is True
|
||||
|
||||
def test_event_filter_navigation_keys_return_false(self, widget):
|
||||
for key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home,
|
||||
Qt.Key.Key_End, Qt.Key.Key_PageUp, Qt.Key.Key_PageDown]:
|
||||
event = QKeyEvent(QEvent.Type.KeyPress, key, Qt.KeyboardModifier.NoModifier)
|
||||
assert widget.eventFilter(widget.popup, event) is False
|
||||
|
||||
def test_event_filter_other_key_hides_popup(self, widget):
|
||||
event = QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_A, Qt.KeyboardModifier.NoModifier)
|
||||
with patch.object(widget.popup, 'hide') as mock_hide:
|
||||
widget.eventFilter(widget.popup, event)
|
||||
mock_hide.assert_called_once()
|
||||
|
|
@ -90,6 +90,217 @@ class TestSpeakerIdentificationWidget:
|
|||
assert (result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bien venue dans. '}]]
|
||||
or result == [[{'end_time': 8904, 'speaker': 'Speaker 0', 'start_time': 140, 'text': 'Bienvenue dans. '}]])
|
||||
|
||||
def test_identify_button_toggles_visibility(self, qtbot: QtBot, transcription, transcription_service):
|
||||
widget = SpeakerIdentificationWidget(
|
||||
transcription=transcription,
|
||||
transcription_service=transcription_service,
|
||||
)
|
||||
qtbot.addWidget(widget)
|
||||
|
||||
# Before: identify visible, cancel hidden
|
||||
assert not widget.step_1_button.isHidden()
|
||||
assert widget.cancel_button.isHidden()
|
||||
|
||||
from PyQt6.QtCore import QThread as RealQThread
|
||||
mock_thread = MagicMock(spec=RealQThread)
|
||||
mock_thread.started = MagicMock()
|
||||
mock_thread.started.connect = MagicMock()
|
||||
|
||||
with patch.object(widget, '_cleanup_thread'), \
|
||||
patch('buzz.widgets.transcription_viewer.speaker_identification_widget.QThread', return_value=mock_thread), \
|
||||
patch.object(widget, 'worker', create=True):
|
||||
# patch moveToThread on IdentificationWorker to avoid type error
|
||||
with patch.object(IdentificationWorker, 'moveToThread'):
|
||||
widget.on_identify_button_clicked()
|
||||
|
||||
# After: identify hidden, cancel visible
|
||||
assert widget.step_1_button.isHidden()
|
||||
assert not widget.cancel_button.isHidden()
|
||||
|
||||
widget.close()
|
||||
|
||||
def test_cancel_button_resets_ui(self, qtbot: QtBot, transcription, transcription_service):
    """Cancelling an in-flight identification restores the initial UI state."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    # Put the widget into the "identification running" state by hand.
    widget.step_1_button.setVisible(False)
    widget.cancel_button.setVisible(True)

    with patch.object(widget, '_cleanup_thread'):
        widget.on_cancel_button_clicked()

    # Buttons flip back, progress resets, and the label keeps some text.
    assert not widget.step_1_button.isHidden()
    assert widget.cancel_button.isHidden()
    assert widget.progress_bar.value() == 0
    assert len(widget.progress_label.text()) > 0

    widget.close()
def test_on_progress_update_sets_label_and_bar(self, qtbot: QtBot, transcription, transcription_service):
    """A progress message updates both the label text and the bar value."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    widget.on_progress_update("3/8 Loading alignment model")

    # The step number ("3" of "3/8") drives the progress bar.
    assert widget.progress_label.text() == "3/8 Loading alignment model"
    assert widget.progress_bar.value() == 3

    widget.close()
def test_on_progress_update_step_8_enables_save(self, qtbot: QtBot, transcription, transcription_service):
    """Reaching the final step (8/8) enables saving and the step-2 controls."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    # Save is disabled until identification completes.
    assert not widget.save_button.isEnabled()

    widget.on_progress_update("8/8 Identification done")

    assert widget.save_button.isEnabled()
    assert widget.step_2_group_box.isEnabled()
    assert widget.merge_speaker_sentences.isEnabled()

    widget.close()
def test_on_identification_finished_empty_result(self, qtbot: QtBot, transcription, transcription_service):
    """An empty identification result is stored but adds no preview rows."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    rows_before = widget.speaker_preview_row.count()

    widget.on_identification_finished([])

    assert widget.identification_result == []
    # Empty result returns early, leaving the preview row untouched.
    assert widget.speaker_preview_row.count() == rows_before

    widget.close()
def test_on_identification_finished_populates_speakers(self, qtbot: QtBot, transcription, transcription_service):
    """A non-empty result creates one preview row per distinct speaker."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    diarized = [
        {'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 3000, 'text': 'Hello world.'},
        {'speaker': 'Speaker 1', 'start_time': 3000, 'end_time': 6000, 'text': 'Hi there.'},
    ]
    widget.on_identification_finished(diarized)

    assert widget.identification_result == diarized
    # One row per speaker: Speaker 0 and Speaker 1.
    assert widget.speaker_preview_row.count() == 2

    widget.close()
def test_on_identification_error_resets_buttons(self, qtbot: QtBot, transcription, transcription_service):
    """An identification error returns the buttons and progress to rest."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    # Simulate the running state, then report an error.
    widget.step_1_button.setVisible(False)
    widget.cancel_button.setVisible(True)

    widget.on_identification_error("Some error")

    assert not widget.step_1_button.isHidden()
    assert widget.cancel_button.isHidden()
    assert widget.progress_bar.value() == 0

    widget.close()
def test_on_save_no_merge(self, qtbot: QtBot, transcription, transcription_service):
    """Saving without merging produces one segment per result entry."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    diarized = [
        {'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 2000, 'text': 'Hello.'},
        {'speaker': 'Speaker 0', 'start_time': 2000, 'end_time': 4000, 'text': 'World.'},
        {'speaker': 'Speaker 1', 'start_time': 4000, 'end_time': 6000, 'text': 'Hi.'},
    ]
    widget.on_identification_finished(diarized)
    widget.merge_speaker_sentences.setChecked(False)

    with patch.object(widget.transcription_service, 'copy_transcription', return_value=uuid.uuid4()) as copy_mock, \
            patch.object(widget.transcription_service, 'update_transcription_as_completed') as update_mock:
        widget.on_save_button_clicked()

    copy_mock.assert_called_once()
    update_mock.assert_called_once()
    saved_segments = update_mock.call_args[0][1]
    # Merge disabled: all 3 entries survive as 3 separate segments.
    assert len(saved_segments) == 3

    widget.close()
def test_on_save_with_merge(self, qtbot: QtBot, transcription, transcription_service):
    """Saving with merging collapses consecutive same-speaker entries."""
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
    )
    qtbot.addWidget(widget)

    diarized = [
        {'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 2000, 'text': 'Hello.'},
        {'speaker': 'Speaker 0', 'start_time': 2000, 'end_time': 4000, 'text': 'World.'},
        {'speaker': 'Speaker 1', 'start_time': 4000, 'end_time': 6000, 'text': 'Hi.'},
    ]
    widget.on_identification_finished(diarized)
    widget.merge_speaker_sentences.setChecked(True)

    with patch.object(widget.transcription_service, 'copy_transcription', return_value=uuid.uuid4()), \
            patch.object(widget.transcription_service, 'update_transcription_as_completed') as update_mock:
        widget.on_save_button_clicked()

    saved_segments = update_mock.call_args[0][1]
    # The two adjacent Speaker 0 entries merge into one segment; the
    # Speaker 1 entry stays separate, giving 2 segments in total.
    assert len(saved_segments) == 2
    assert "Speaker 0" in saved_segments[0].text
    assert "Hello." in saved_segments[0].text
    assert "World." in saved_segments[0].text

    widget.close()
def test_on_save_emits_transcriptions_updated(self, qtbot: QtBot, transcription, transcription_service):
    """Saving emits the transcriptions-updated signal with the new copy's id."""
    updated_signal = MagicMock()
    widget = SpeakerIdentificationWidget(
        transcription=transcription,
        transcription_service=transcription_service,
        transcriptions_updated_signal=updated_signal,
    )
    qtbot.addWidget(widget)

    widget.on_identification_finished(
        [{'speaker': 'Speaker 0', 'start_time': 0, 'end_time': 1000, 'text': 'Hi.'}]
    )

    copy_id = uuid.uuid4()
    with patch.object(widget.transcription_service, 'copy_transcription', return_value=copy_id), \
            patch.object(widget.transcription_service, 'update_transcription_as_completed'):
        widget.on_save_button_clicked()

    # The signal must carry the id of the freshly copied transcription.
    updated_signal.emit.assert_called_once_with(copy_id)

    widget.close()
def test_batch_processing_with_many_words(self):
|
||||
"""Test batch processing when there are more than 200 words."""
|
||||
# Create mock punctuation model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue