mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-03 00:12:14 +02:00
feat: save transcriptions to sqlite (#682)
This commit is contained in:
parent
dfac983f13
commit
ae5af308b2
4
Makefile
4
Makefile
|
@ -28,9 +28,9 @@ clean:
|
|||
rm -f buzz/whisper_cpp.py
|
||||
rm -rf dist/* || true
|
||||
|
||||
COVERAGE_THRESHOLD := 76
|
||||
COVERAGE_THRESHOLD := 77
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
COVERAGE_THRESHOLD := 71
|
||||
COVERAGE_THRESHOLD := 72
|
||||
endif
|
||||
|
||||
test: buzz/whisper_cpp.py translation_mo
|
||||
|
|
11
buzz/buzz.py
11
buzz/buzz.py
|
@ -6,7 +6,7 @@ import platform
|
|||
import sys
|
||||
from typing import TextIO
|
||||
|
||||
from appdirs import user_log_dir
|
||||
from platformdirs import user_log_dir
|
||||
|
||||
# Check for segfaults if not running in frozen mode
|
||||
if getattr(sys, "frozen", False) is False:
|
||||
|
@ -57,7 +57,14 @@ def main():
|
|||
|
||||
from buzz.cli import parse_command_line
|
||||
from buzz.widgets.application import Application
|
||||
from buzz.db.dao.transcription_dao import TranscriptionDAO
|
||||
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
from buzz.db.db import setup_app_db
|
||||
|
||||
app = Application()
|
||||
db = setup_app_db()
|
||||
app = Application(
|
||||
TranscriptionService(TranscriptionDAO(db), TranscriptionSegmentDAO(db))
|
||||
)
|
||||
parse_command_line(app)
|
||||
sys.exit(app.exec())
|
||||
|
|
0
buzz/db/__init__.py
Normal file
0
buzz/db/__init__.py
Normal file
0
buzz/db/dao/__init__.py
Normal file
0
buzz/db/dao/__init__.py
Normal file
53
buzz/db/dao/dao.py
Normal file
53
buzz/db/dao/dao.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
# Adapted from https://github.com/zhiyiYo/Groove
|
||||
from abc import ABC
|
||||
from typing import TypeVar, Generic, Any, Type
|
||||
|
||||
from PyQt6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord
|
||||
|
||||
from buzz.db.entity.entity import Entity
|
||||
|
||||
T = TypeVar("T", bound=Entity)
|
||||
|
||||
|
||||
class DAO(ABC, Generic[T]):
    """Generic data-access object for a single table on a Qt SQL connection.

    Subclasses set ``entity`` to the class that query results are turned into
    and pass their table name to ``__init__``.
    """

    entity: Type[T]

    def __init__(self, table: str, db: QSqlDatabase):
        self.table = table
        self.db = db

    def insert(self, record: T):
        """Insert ``record`` as a new row.

        The instance attributes of ``record`` (its ``__dict__``) provide both
        the column names and the bound values.

        Raises:
            Exception: carrying the driver's error text when execution fails.
        """
        fields = record.__dict__
        column_list = ", ".join(fields.keys())
        placeholder_list = ", ".join([f":{name}" for name in fields.keys()])

        statement = self._create_query()
        # Only identifiers (table/column names) are interpolated; every value
        # is bound as a named parameter below.
        statement.prepare(
            f"""
            INSERT INTO {self.table} ({column_list})
            VALUES ({placeholder_list})
        """
        )
        for name, value in fields.items():
            statement.bindValue(f":{name}", value)

        if statement.exec():
            return
        raise Exception(statement.lastError().text())

    def find_by_id(self, id: Any) -> T | None:
        """Return the entity whose ``id`` column equals ``id``, or ``None``."""
        lookup = self._create_query()
        lookup.prepare(f"SELECT * FROM {self.table} WHERE id = :id")
        lookup.bindValue(":id", id)
        return self._execute(lookup)

    def to_entity(self, record: QSqlRecord) -> T:
        """Build an entity and copy every field of ``record`` onto it.

        The entity class is instantiated without arguments, then one attribute
        is set per record field, named after the field.
        """
        instance = self.entity()
        for column in range(record.count()):
            setattr(instance, record.fieldName(column), record.value(column))
        return instance

    def _execute(self, query: QSqlQuery) -> T | None:
        """Run a prepared query and convert its first row, if any.

        Raises:
            Exception: carrying the driver's error text when execution fails.
        """
        if not query.exec():
            raise Exception(query.lastError().text())
        return self.to_entity(query.record()) if query.first() else None

    def _create_query(self):
        """Create a new query object bound to this DAO's connection."""
        return QSqlQuery(self.db)
|
159
buzz/db/dao/transcription_dao.py
Normal file
159
buzz/db/dao/transcription_dao.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6.QtSql import QSqlDatabase
|
||||
|
||||
from buzz.db.dao.dao import DAO
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask
|
||||
|
||||
|
||||
class TranscriptionDAO(DAO[Transcription]):
    """Data-access object for the ``transcription`` table."""

    entity = Transcription

    def __init__(self, db: QSqlDatabase):
        super().__init__("transcription", db)

    def _exec_with_bindings(self, sql: str, bindings: dict):
        """Prepare ``sql``, bind each named placeholder, and execute it.

        Raises:
            Exception: carrying the driver's error text when execution fails.
        """
        statement = self._create_query()
        statement.prepare(sql)
        for placeholder, value in bindings.items():
            statement.bindValue(placeholder, value)
        if not statement.exec():
            raise Exception(statement.lastError().text())

    def create_transcription(self, task: FileTranscriptionTask):
        """Insert a row for ``task`` with status QUEUED and the current time."""
        options = task.transcription_options
        model = options.model
        # Output formats are stored as one comma-separated string.
        export_formats = ", ".join(
            [
                output_format.value
                for output_format in task.file_transcription_options.output_formats
            ]
        )
        whisper_model_size = (
            model.whisper_model_size.value if model.whisper_model_size else None
        )

        self._exec_with_bindings(
            """
            INSERT INTO transcription (
                id,
                export_formats,
                file,
                output_folder,
                language,
                model_type,
                source,
                status,
                task,
                time_queued,
                url,
                whisper_model_size
            ) VALUES (
                :id,
                :export_formats,
                :file,
                :output_folder,
                :language,
                :model_type,
                :source,
                :status,
                :task,
                :time_queued,
                :url,
                :whisper_model_size
            )
            """,
            {
                ":id": str(task.uid),
                ":export_formats": export_formats,
                ":file": task.file_path,
                ":output_folder": task.output_directory,
                ":language": options.language,
                ":model_type": model.model_type.value,
                ":source": task.source.value,
                ":status": FileTranscriptionTask.Status.QUEUED.value,
                ":task": options.task.value,
                ":time_queued": datetime.now().isoformat(),
                ":url": task.url,
                ":whisper_model_size": whisper_model_size,
            },
        )

    def update_transcription_as_started(self, id: UUID):
        """Set status to IN_PROGRESS and record the start time."""
        self._exec_with_bindings(
            """
            UPDATE transcription
            SET status = :status, time_started = :time_started
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.IN_PROGRESS.value,
                ":time_started": datetime.now().isoformat(),
            },
        )

    def update_transcription_as_failed(self, id: UUID, error: str):
        """Set status to FAILED, record the end time and store ``error``."""
        self._exec_with_bindings(
            """
            UPDATE transcription
            SET status = :status, time_ended = :time_ended, error_message = :error_message
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.FAILED.value,
                ":time_ended": datetime.now().isoformat(),
                ":error_message": error,
            },
        )

    def update_transcription_as_canceled(self, id: UUID):
        """Set status to CANCELED and record the end time."""
        self._exec_with_bindings(
            """
            UPDATE transcription
            SET status = :status, time_ended = :time_ended
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.CANCELED.value,
                ":time_ended": datetime.now().isoformat(),
            },
        )

    def update_transcription_progress(self, id: UUID, progress: float):
        """Store ``progress`` and set status to IN_PROGRESS."""
        self._exec_with_bindings(
            """
            UPDATE transcription
            SET status = :status, progress = :progress
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.IN_PROGRESS.value,
                ":progress": progress,
            },
        )

    def update_transcription_as_completed(self, id: UUID):
        """Set status to COMPLETED and record the end time."""
        self._exec_with_bindings(
            """
            UPDATE transcription
            SET status = :status, time_ended = :time_ended
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.COMPLETED.value,
                ":time_ended": datetime.now().isoformat(),
            },
        )
|
11
buzz/db/dao/transcription_segment_dao.py
Normal file
11
buzz/db/dao/transcription_segment_dao.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from PyQt6.QtSql import QSqlDatabase
|
||||
|
||||
from buzz.db.dao.dao import DAO
|
||||
from buzz.db.entity.transcription_segment import TranscriptionSegment
|
||||
|
||||
|
||||
class TranscriptionSegmentDAO(DAO[TranscriptionSegment]):
    """Data-access object for the ``transcription_segment`` table."""

    entity = TranscriptionSegment

    def __init__(self, db: QSqlDatabase):
        super().__init__(table="transcription_segment", db=db)
|
39
buzz/db/db.py
Normal file
39
buzz/db/db.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
|
||||
from PyQt6.QtSql import QSqlDatabase
|
||||
from platformdirs import user_data_dir
|
||||
|
||||
from buzz.db.helpers import (
|
||||
run_sqlite_migrations,
|
||||
copy_transcriptions_from_json_to_sqlite,
|
||||
mark_in_progress_and_queued_transcriptions_as_canceled,
|
||||
)
|
||||
|
||||
APP_DB_PATH = os.path.join(user_data_dir("Buzz"), "Buzz.sqlite")
|
||||
|
||||
|
||||
def setup_app_db() -> QSqlDatabase:
    """Open the application database at ``APP_DB_PATH``, creating it if needed.

    Returns:
        The opened Qt database connection.

    Raises:
        RuntimeError: if the Qt connection cannot be opened.
    """
    # platformdirs only computes the data directory path (it does not create
    # it by default), and sqlite3.connect() cannot create a database inside a
    # missing directory, so make sure the directory exists first.
    os.makedirs(os.path.dirname(APP_DB_PATH), exist_ok=True)
    return _setup_db(APP_DB_PATH)
|
||||
|
||||
|
||||
def setup_test_db() -> QSqlDatabase:
    """Open a throwaway database in a fresh temporary file, for tests.

    Returns:
        The opened Qt database connection.

    Raises:
        RuntimeError: if the Qt connection cannot be opened.
    """
    # mkstemp() atomically creates the file, unlike the deprecated mktemp(),
    # which only returns a name that another process could claim first.
    # SQLite treats the resulting zero-length file as an empty database.
    handle, path = tempfile.mkstemp()
    os.close(handle)
    return _setup_db(path)
|
||||
|
||||
|
||||
def _setup_db(path: str) -> QSqlDatabase:
    """Migrate the SQLite file at ``path`` and open it as a Qt connection.

    The schema migration, the import of legacy JSON tasks, and the cleanup of
    unfinished transcriptions run on a short-lived ``sqlite3`` connection.
    The file is then opened through the ``QSQLITE`` driver.

    Args:
        path: file system path of the SQLite database.

    Returns:
        The opened Qt database connection.

    Raises:
        RuntimeError: if the Qt connection cannot be opened.
    """
    # Run migrations
    migration_conn = sqlite3.connect(path)
    try:
        run_sqlite_migrations(migration_conn)
        copy_transcriptions_from_json_to_sqlite(migration_conn)
        mark_in_progress_and_queued_transcriptions_as_canceled(migration_conn)
    finally:
        # Release the file handle even when a migration step raises.
        migration_conn.close()

    db = QSqlDatabase.addDatabase("QSQLITE")
    db.setDatabaseName(path)
    if not db.open():
        raise RuntimeError(f"Failed to open database connection: {db.databaseName()}")
    logging.debug("Database connection opened: %s", db.databaseName())
    return db
|
0
buzz/db/entity/__init__.py
Normal file
0
buzz/db/entity/__init__.py
Normal file
12
buzz/db/entity/entity.py
Normal file
12
buzz/db/entity/entity.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
from abc import ABC
|
||||
|
||||
from PyQt6.QtSql import QSqlRecord
|
||||
|
||||
|
||||
class Entity(ABC):
    """Base class for objects populated from database records."""

    @classmethod
    def from_record(cls, record: QSqlRecord):
        """Create an instance and copy every field of ``record`` onto it.

        The class is instantiated without arguments; each record field then
        becomes an attribute named after the field.
        """
        instance = cls()
        field_total = record.count()
        for position in range(field_total):
            setattr(instance, record.fieldName(position), record.value(position))
        return instance
|
54
buzz/db/entity/transcription.py
Normal file
54
buzz/db/entity/transcription.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from buzz.db.entity.entity import Entity
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber.transcriber import OutputFormat, Task, FileTranscriptionTask
|
||||
|
||||
|
||||
@dataclass
class Transcription(Entity):
    """A stored transcription, with enum-like values kept as plain strings."""

    status: str = FileTranscriptionTask.Status.QUEUED.value
    task: str = Task.TRANSCRIBE.value
    model_type: str = ModelType.WHISPER.value
    whisper_model_size: str | None = None
    language: str | None = None
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    error_message: str | None = None
    file: str | None = None
    # A default_factory is required here: a plain default expression would be
    # evaluated once at class-definition time, so every instance would share
    # the timestamp of the moment this module was imported.
    time_queued: str = field(
        default_factory=lambda: datetime.datetime.now().isoformat()
    )

    @property
    def id_as_uuid(self):
        """Return ``id`` parsed as a ``uuid.UUID``."""
        return uuid.UUID(hex=self.id)

    @property
    def status_as_status(self):
        """Return ``status`` as a ``FileTranscriptionTask.Status`` member."""
        return FileTranscriptionTask.Status(self.status)

    def get_output_file_path(
        self,
        output_format: OutputFormat,
        output_directory: str | None = None,
    ):
        """Build the export path for this transcription.

        The file name comes from the export template in ``Settings``, with its
        ``{{ ... }}`` placeholders filled from this transcription and the
        current time, followed by the extension of ``output_format``.

        Args:
            output_format: format whose value is used as the file extension.
            output_directory: target directory; defaults to the directory of
                the input file.

        Returns:
            The joined output directory and generated file name.
        """
        input_file_name = os.path.splitext(os.path.basename(self.file))[0]

        date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")

        export_file_name_template = Settings().get_default_export_file_template()

        output_file_name = (
            export_file_name_template.replace("{{ input_file_name }}", input_file_name)
            .replace("{{ task }}", self.task)
            .replace("{{ language }}", self.language or "")
            .replace("{{ model_type }}", self.model_type)
            .replace("{{ model_size }}", self.whisper_model_size or "")
            .replace("{{ date_time }}", date_time_now)
            + f".{output_format.value}"
        )

        output_directory = output_directory or os.path.dirname(self.file)
        return os.path.join(output_directory, output_file_name)
|
11
buzz/db/entity/transcription_segment.py
Normal file
11
buzz/db/entity/transcription_segment.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from buzz.db.entity.entity import Entity
|
||||
|
||||
|
||||
@dataclass
class TranscriptionSegment(Entity):
    """One text segment belonging to a transcription."""

    # Start and end of the segment, as provided by the transcriber.
    start_time: int
    end_time: int
    # Transcribed text of the segment.
    text: str
    # ``id`` of the transcription row this segment belongs to.
    transcription_id: str
|
84
buzz/db/helpers.py
Normal file
84
buzz/db/helpers.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
from datetime import datetime
|
||||
from sqlite3 import Connection
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.db.migrator import dumb_migrate_db
|
||||
|
||||
|
||||
def copy_transcriptions_from_json_to_sqlite(conn: Connection):
    """Import tasks from the legacy JSON tasks cache into SQLite.

    For each cached task, one ``transcription`` row is inserted, plus one
    ``transcription_segment`` row per segment. Nothing is inserted when the
    cache file does not exist. All inserts are committed together at the end.

    Args:
        conn: open ``sqlite3`` connection to the application database.
    """
    cache = TasksCache()
    if os.path.exists(cache.tasks_list_file_path):
        cursor = conn.cursor()
        for task in cache.load():
            # The id is generated by the application, so it does not need to
            # be read back with RETURNING (which requires SQLite >= 3.35).
            transcription_id = str(task.uid)
            model = task.transcription_options.model
            export_formats = ", ".join(
                [
                    output_format.value
                    for output_format in task.file_transcription_options.output_formats
                ]
            )
            cursor.execute(
                """
                INSERT INTO transcription (id, error_message, export_formats, file, output_folder, progress, language, model_type, source, status, task, time_ended, time_queued, time_started, url, whisper_model_size)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
                """,
                (
                    transcription_id,
                    task.error,
                    export_formats,
                    task.file_path,
                    task.output_directory,
                    task.fraction_completed,
                    task.transcription_options.language,
                    model.model_type.value,
                    task.source.value,
                    task.status.value,
                    task.transcription_options.task.value,
                    task.completed_at,
                    task.queued_at,
                    task.started_at,
                    task.url,
                    model.whisper_model_size.value
                    if model.whisper_model_size
                    else None,
                ),
            )

            cursor.executemany(
                """
                INSERT INTO transcription_segment (end_time, start_time, text, transcription_id)
                VALUES (?, ?, ?, ?);
                """,
                [
                    (segment.end, segment.start, segment.text, transcription_id)
                    for segment in task.segments
                ],
            )
        # NOTE(review): the legacy cache file is deliberately left in place
        # (its removal was disabled), so this import runs again on every start
        # while the file exists - confirm that repeated imports are intended.
        # os.remove(cache.tasks_list_file_path)
    conn.commit()
|
||||
|
||||
|
||||
def run_sqlite_migrations(db: Connection):
    """Migrate ``db`` to the schema declared in the bundled ``schema.sql``.

    Args:
        db: open ``sqlite3`` connection with no transaction in progress.
    """
    schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")

    # Read with an explicit encoding so the result does not depend on the
    # platform's locale default.
    with open(schema_path, encoding="utf-8") as schema_file:
        schema = schema_file.read()
    dumb_migrate_db(db=db, schema=schema)
|
||||
|
||||
|
||||
def mark_in_progress_and_queued_transcriptions_as_canceled(conn: Connection):
    """Cancel every transcription that is still queued or in progress.

    Matching rows get status ``'canceled'`` and the current time as
    ``time_ended``; the change is committed before returning.

    Args:
        conn: open ``sqlite3`` connection to the application database.
    """
    ended_at = datetime.now().isoformat()
    statement = """
        UPDATE transcription
        SET status = 'canceled', time_ended = ?
        WHERE status = 'in_progress' OR status = 'queued';
    """
    conn.cursor().execute(statement, (ended_at,))
    conn.commit()
|
284
buzz/db/migrator.py
Normal file
284
buzz/db/migrator.py
Normal file
|
@ -0,0 +1,284 @@
|
|||
# coding: utf-8
|
||||
# https://gist.github.com/simonw/664b4b0851c1899dc55e1fb655181037
|
||||
|
||||
"""Simple declarative schema migration for SQLite.
|
||||
See <https://david.rothlis.net/declarative-schema-migration-for-sqlite>.
|
||||
Author: William Manley <will@stb-tester.com>.
|
||||
Copyright © 2019-2022 Stb-tester.com Ltd.
|
||||
License: MIT.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
def dumb_migrate_db(db, schema, allow_deletions=False):
    """
    Bring `db` in line with the schema described by the SQL text `schema`,
    keeping the existing data. Tables present in the schema are created, old
    tables that are no longer used are deleted, and columns and indices are
    added or removed as required.

    The following schema changes are handled:

    1. Adding a new table
    2. Adding, deleting or modifying an index
    3. Adding a column to an existing table, provided the new column can be
       NULL or has a DEFAULT value specified
    4. Changing a column to remove NULL or DEFAULT, provided all values in
       the database are not NULL
    5. Changing the type of a column
    6. Changing the user_version

    The following are additionally possible, but only with
    allow_deletions=True:

    1. Deleting tables
    2. Deleting columns from tables

    If the new schema requires a column/table to be deleted and
    allow_deletions=False, `RuntimeError` is raised.

    Note: no transaction may be open on db when this function is called; one
    is used internally. To perform additional steps as part of a migration,
    use DBMigrator directly.

    Rowid columns generated internally by SQLite may change values as a
    result of this migration.

    Returns True if any change was made to the database, False otherwise.
    """
    migrator = DBMigrator(db, schema, allow_deletions)
    with migrator:
        migrator.migrate()
    return migrator.n_changes != 0
|
||||
|
||||
|
||||
class DBMigrator:
    """Context manager that migrates a SQLite database to a declared schema.

    The target schema is loaded into a private in-memory database
    (``pristine``). ``migrate()`` compares it with ``db`` and issues the
    statements needed to make ``db`` match. Use as::

        with DBMigrator(db, schema) as migrator:
            migrator.migrate()

    An instance is single-use: the in-memory database is closed when the
    ``with`` block exits.
    """

    def __init__(self, db, schema, allow_deletions=False):
        """
        Args:
            db: open ``sqlite3`` connection to migrate.
            schema: SQL text declaring the desired schema.
            allow_deletions: permit dropping tables/columns missing from
                ``schema``; otherwise such a migration raises RuntimeError.
        """
        self.db = db
        self.schema = schema
        self.allow_deletions = allow_deletions

        self.pristine = sqlite3.connect(":memory:")
        self.pristine.executescript(schema)
        self.n_changes = 0

        self.orig_foreign_keys = None

    def log_execute(self, msg, sql, args=None):
        """Log ``sql`` with description ``msg``, execute it and count it."""
        # It's important to log any changes we're making to the database for
        # forensics later
        msg_tmpl = "Database migration: %s with SQL:\n%s"
        msg_argv = (msg, _left_pad(dedent(sql)))
        if args:
            msg_tmpl += " args = %r"
            msg_argv += (args,)
        else:
            args = []
        logging.info(msg_tmpl, *msg_argv)
        self.db.execute(sql, args)
        self.n_changes += 1

    def __enter__(self):
        """Disable foreign keys if they were on, then open a transaction."""
        self.orig_foreign_keys = self.db.execute("PRAGMA foreign_keys").fetchone()[0]
        if self.orig_foreign_keys:
            self.log_execute(
                "Disable foreign keys temporarily for migration",
                "PRAGMA foreign_keys = OFF",
            )
            # This doesn't count as a change because we'll undo it at the end
            self.n_changes = 0

        self.db.__enter__()
        self.db.execute("BEGIN")
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Finish the transaction, restore pragmas and release ``pristine``."""
        try:
            self.db.__exit__(exc_type, exc_value, exc_tb)
            if exc_value is None:
                # The SQLite docs say:
                #
                # > This pragma is a no-op within a transaction; foreign key
                # > constraint enforcement may only be enabled or disabled when
                # > there is no pending BEGIN or SAVEPOINT.
                old_changes = self.n_changes
                new_val = self._migrate_pragma("foreign_keys")
                if new_val == self.orig_foreign_keys:
                    self.n_changes = old_changes

                # SQLite docs say:
                #
                # > A VACUUM will fail if there is an open transaction on the database
                # > connection that is attempting to run the VACUUM.
                if self.n_changes:
                    self.db.execute("VACUUM")
            else:
                if self.orig_foreign_keys:
                    self.log_execute(
                        "Re-enable foreign keys after migration", "PRAGMA foreign_keys = ON"
                    )
        finally:
            # The in-memory reference database is owned by this object and is
            # not needed after the migration; close it explicitly instead of
            # leaving it to the garbage collector.
            self.pristine.close()

    def migrate(self):
        """Apply table, column, index and user_version changes to ``db``.

        Raises:
            RuntimeError: if tables or columns would have to be deleted while
                ``allow_deletions`` is false, or if the result would fail the
                foreign key check.
        """
        # In CI the database schema may be changing all the time. This checks
        # the current db and if it doesn't match database.sql we will
        # modify it so it does match where possible.
        pristine_tables = dict(
            self.pristine.execute(
                """\
                SELECT name, sql FROM sqlite_master
                WHERE type = \"table\" AND name != \"sqlite_sequence\""""
            ).fetchall()
        )
        pristine_indices = dict(
            self.pristine.execute(
                """\
                SELECT name, sql FROM sqlite_master
                WHERE type = \"index\""""
            ).fetchall()
        )

        tables = dict(
            self.db.execute(
                """\
                SELECT name, sql FROM sqlite_master
                WHERE type = \"table\" AND name != \"sqlite_sequence\""""
            ).fetchall()
        )

        new_tables = set(pristine_tables.keys()) - set(tables.keys())
        removed_tables = set(tables.keys()) - set(pristine_tables.keys())
        if removed_tables and not self.allow_deletions:
            raise RuntimeError(
                "Database migration: Refusing to delete tables %r" % removed_tables
            )

        modified_tables = set(
            name
            for name, sql in pristine_tables.items()
            if normalise_sql(tables.get(name, "")) != normalise_sql(sql)
        )

        # This PRAGMA is automatically disabled when the db is committed
        self.db.execute("PRAGMA defer_foreign_keys = TRUE")

        # New and removed tables are easy:
        for tbl_name in new_tables:
            self.log_execute("Create table %s" % tbl_name, pristine_tables[tbl_name])
        for tbl_name in removed_tables:
            self.log_execute("Drop table %s" % tbl_name, "DROP TABLE %s" % tbl_name)

        for tbl_name in modified_tables:
            # The SQLite documentation insists that we create the new table and
            # rename it over the old rather than moving the old out of the way
            # and then creating the new
            create_table_sql = pristine_tables[tbl_name]
            create_table_sql = re.sub(
                r"\b%s\b" % re.escape(tbl_name),
                tbl_name + "_migration_new",
                create_table_sql,
            )
            self.log_execute(
                "Columns change: Create table %s with updated schema" % tbl_name,
                create_table_sql,
            )

            cols = set(
                [x[1] for x in self.db.execute("PRAGMA table_info(%s)" % tbl_name)]
            )
            pristine_cols = set(
                [
                    x[1]
                    for x in self.pristine.execute("PRAGMA table_info(%s)" % tbl_name)
                ]
            )

            removed_columns = cols - pristine_cols
            if not self.allow_deletions and removed_columns:
                logging.warning(
                    "Database migration: Refusing to remove columns %r from "
                    "table %s.  Current cols are %r attempting migration to %r",
                    removed_columns,
                    tbl_name,
                    cols,
                    pristine_cols,
                )
                raise RuntimeError(
                    "Database migration: Refusing to remove columns %r from "
                    "table %s" % (removed_columns, tbl_name)
                )

            logging.info("cols: %s, pristine_cols: %s", cols, pristine_cols)
            # Only columns present in both versions can be copied across.
            self.log_execute(
                "Migrate data for table %s" % tbl_name,
                """\
                INSERT INTO {tbl_name}_migration_new ({common})
                SELECT {common} FROM {tbl_name}""".format(
                    tbl_name=tbl_name,
                    common=", ".join(cols.intersection(pristine_cols)),
                ),
            )

            # Don't need the old table any more
            self.log_execute(
                "Drop old table %s now data has been migrated" % tbl_name,
                "DROP TABLE %s" % tbl_name,
            )

            self.log_execute(
                "Columns change: Move new table %s over old" % tbl_name,
                "ALTER TABLE %s_migration_new RENAME TO %s" % (tbl_name, tbl_name),
            )

        # Migrate the indices
        indices = dict(
            self.db.execute(
                """\
                SELECT name, sql FROM sqlite_master
                WHERE type = \"index\""""
            ).fetchall()
        )
        for name in set(indices.keys()) - set(pristine_indices.keys()):
            self.log_execute(
                "Dropping obsolete index %s" % name, "DROP INDEX %s" % name
            )
        for name, sql in pristine_indices.items():
            if name not in indices:
                self.log_execute("Creating new index %s" % name, sql)
            elif sql != indices[name]:
                self.log_execute(
                    "Index %s changed: Dropping old version" % name,
                    "DROP INDEX %s" % name,
                )
                self.log_execute(
                    "Index %s changed: Creating updated version in its place" % name,
                    sql,
                )

        self._migrate_pragma("user_version")

        if self.pristine.execute("PRAGMA foreign_keys").fetchone()[0]:
            if self.db.execute("PRAGMA foreign_key_check").fetchall():
                raise RuntimeError("Database migration: Would fail foreign_key_check")

    def _migrate_pragma(self, pragma):
        """Set ``pragma`` on ``db`` to the pristine value if they differ.

        Returns:
            The value of the pragma in the pristine database.
        """
        pristine_val = self.pristine.execute("PRAGMA %s" % pragma).fetchone()[0]
        val = self.db.execute("PRAGMA %s" % pragma).fetchone()[0]

        if val != pristine_val:
            self.log_execute(
                "Set %s to %i from %i" % (pragma, pristine_val, val),
                "PRAGMA %s = %i" % (pragma, pristine_val),
            )

        return pristine_val
|
||||
|
||||
|
||||
def _left_pad(text, indent=" "):
|
||||
"""Maybe I can find a package in pypi for this?"""
|
||||
return "\n".join(indent + line for line in text.split("\n"))
|
||||
|
||||
|
||||
def normalise_sql(sql):
    """Return ``sql`` in a canonical form for comparing schema definitions.

    Newline-terminated ``--`` comments are dropped, whitespace runs collapse
    to one space, spaces around ``(``, ``)`` and ``,`` are removed, and double
    quotes around plain word identifiers are stripped.
    """
    rewrites = (
        # Remove comments:
        (r"--[^\n]*\n", ""),
        # Normalise whitespace:
        (r"\s+", " "),
        (r" *([(),]) *", r"\1"),
        # Remove unnecessary quotes
        (r'"(\w+)"', r"\1"),
    )
    for pattern, replacement in rewrites:
        sql = re.sub(pattern, replacement, sql)
    return sql.strip()
|
28
buzz/db/schema.sql
Normal file
28
buzz/db/schema.sql
Normal file
|
@ -0,0 +1,28 @@
|
|||
-- One row per transcription task. The application stores the task's UUID
-- as text in id and writes the time_* columns as ISO-8601 strings.
-- Comments are kept outside the statements so the SQL text recorded in
-- sqlite_master stays unchanged.
CREATE TABLE transcription (
    id TEXT PRIMARY KEY,
    error_message TEXT,
    export_formats TEXT,
    file TEXT,
    output_folder TEXT,
    progress DOUBLE PRECISION DEFAULT 0.0,
    language TEXT,
    model_type TEXT,
    source TEXT,
    status TEXT,
    task TEXT,
    time_ended TIMESTAMP,
    time_queued TIMESTAMP NOT NULL,
    time_started TIMESTAMP,
    url TEXT,
    whisper_model_size TEXT
);

-- Text segments of a transcription; transcription_id references
-- transcription(id).
CREATE TABLE transcription_segment (
    id INTEGER PRIMARY KEY,
    end_time INT DEFAULT 0,
    start_time INT DEFAULT 0,
    text TEXT NOT NULL,
    transcription_id TEXT,
    FOREIGN KEY (transcription_id) REFERENCES transcription(id) ON DELETE CASCADE
);
-- Index on the foreign-key column of transcription_segment.
CREATE INDEX idx_transcription_id ON transcription_segment(transcription_id);
|
0
buzz/db/service/__init__.py
Normal file
0
buzz/db/service/__init__.py
Normal file
44
buzz/db/service/transcription_service.py
Normal file
44
buzz/db/service/transcription_service.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from buzz.db.dao.transcription_dao import TranscriptionDAO
|
||||
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
|
||||
from buzz.db.entity.transcription_segment import TranscriptionSegment
|
||||
from buzz.transcriber.transcriber import Segment
|
||||
|
||||
|
||||
class TranscriptionService:
    """Service layer over the transcription and transcription-segment DAOs."""

    def __init__(
        self,
        transcription_dao: TranscriptionDAO,
        transcription_segment_dao: TranscriptionSegmentDAO,
    ):
        self.transcription_segment_dao = transcription_segment_dao
        self.transcription_dao = transcription_dao

    def create_transcription(self, task):
        """Store a new transcription row for ``task``."""
        self.transcription_dao.create_transcription(task)

    def update_transcription_as_started(self, id: UUID):
        """Mark the transcription ``id`` as started."""
        self.transcription_dao.update_transcription_as_started(id)

    def update_transcription_as_failed(self, id: UUID, error: str):
        """Mark the transcription ``id`` as failed with message ``error``."""
        self.transcription_dao.update_transcription_as_failed(id, error)

    def update_transcription_as_canceled(self, id: UUID):
        """Mark the transcription ``id`` as canceled."""
        self.transcription_dao.update_transcription_as_canceled(id)

    def update_transcription_progress(self, id: UUID, progress: float):
        """Store ``progress`` for the transcription ``id``."""
        self.transcription_dao.update_transcription_progress(id, progress)

    def update_transcription_as_completed(self, id: UUID, segments: List[Segment]):
        """Mark the transcription ``id`` as completed and store its segments.

        The status update is written first; each segment is then inserted as
        its own row referencing the transcription.
        """
        self.transcription_dao.update_transcription_as_completed(id)
        transcription_id = str(id)
        for segment in segments:
            row = TranscriptionSegment(
                start_time=segment.start,
                end_time=segment.end,
                text=segment.text,
                transcription_id=transcription_id,
            )
            self.transcription_segment_dao.insert(row)
|
|
@ -1,8 +1,8 @@
|
|||
import datetime
|
||||
import logging
|
||||
import multiprocessing
|
||||
import queue
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import Optional, Tuple, List, Set
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
|
||||
|
||||
|
@ -21,13 +21,19 @@ class FileTranscriberQueueWorker(QObject):
|
|||
current_task: Optional[FileTranscriptionTask] = None
|
||||
current_transcriber: Optional[FileTranscriber] = None
|
||||
current_transcriber_thread: Optional[QThread] = None
|
||||
task_updated = pyqtSignal(FileTranscriptionTask)
|
||||
|
||||
task_started = pyqtSignal(FileTranscriptionTask)
|
||||
task_progress = pyqtSignal(FileTranscriptionTask, float)
|
||||
task_download_progress = pyqtSignal(FileTranscriptionTask, float)
|
||||
task_completed = pyqtSignal(FileTranscriptionTask, list)
|
||||
task_error = pyqtSignal(FileTranscriptionTask, str)
|
||||
|
||||
completed = pyqtSignal()
|
||||
|
||||
def __init__(self, parent: Optional[QObject] = None):
|
||||
super().__init__(parent)
|
||||
self.tasks_queue = queue.Queue()
|
||||
self.canceled_tasks = set()
|
||||
self.canceled_tasks: Set[UUID] = set()
|
||||
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
|
@ -42,7 +48,7 @@ class FileTranscriberQueueWorker(QObject):
|
|||
self.completed.emit()
|
||||
return
|
||||
|
||||
if self.current_task.id in self.canceled_tasks:
|
||||
if self.current_task.uid in self.canceled_tasks:
|
||||
continue
|
||||
|
||||
break
|
||||
|
@ -91,53 +97,41 @@ class FileTranscriberQueueWorker(QObject):
|
|||
self.current_transcriber.error.connect(self.run)
|
||||
self.current_transcriber.completed.connect(self.run)
|
||||
|
||||
self.current_task.started_at = datetime.datetime.now()
|
||||
self.task_started.emit(self.current_task)
|
||||
self.current_transcriber_thread.start()
|
||||
|
||||
def add_task(self, task: FileTranscriptionTask):
|
||||
if task.queued_at is None:
|
||||
task.queued_at = datetime.datetime.now()
|
||||
|
||||
self.tasks_queue.put(task)
|
||||
task.status = FileTranscriptionTask.Status.QUEUED
|
||||
self.task_updated.emit(task)
|
||||
|
||||
def cancel_task(self, task_id: int):
|
||||
def cancel_task(self, task_id: UUID):
|
||||
self.canceled_tasks.add(task_id)
|
||||
|
||||
if self.current_task.id == task_id:
|
||||
if self.current_task.uid == task_id:
|
||||
if self.current_transcriber is not None:
|
||||
self.current_transcriber.stop()
|
||||
|
||||
def on_task_error(self, error: str):
|
||||
if (
|
||||
self.current_task is not None
|
||||
and self.current_task.id not in self.canceled_tasks
|
||||
and self.current_task.uid not in self.canceled_tasks
|
||||
):
|
||||
self.current_task.status = FileTranscriptionTask.Status.FAILED
|
||||
self.current_task.error = error
|
||||
self.task_updated.emit(self.current_task)
|
||||
self.task_error.emit(self.current_task, error)
|
||||
|
||||
@pyqtSlot(tuple)
|
||||
def on_task_progress(self, progress: Tuple[int, int]):
|
||||
if self.current_task is not None:
|
||||
self.current_task.status = FileTranscriptionTask.Status.IN_PROGRESS
|
||||
self.current_task.fraction_completed = progress[0] / progress[1]
|
||||
self.task_updated.emit(self.current_task)
|
||||
self.task_progress.emit(self.current_task, progress[0] / progress[1])
|
||||
|
||||
def on_task_download_progress(self, fraction_downloaded: float):
|
||||
if self.current_task is not None:
|
||||
self.current_task.status = FileTranscriptionTask.Status.IN_PROGRESS
|
||||
self.current_task.fraction_downloaded = fraction_downloaded
|
||||
self.task_updated.emit(self.current_task)
|
||||
self.task_download_progress.emit(self.current_task, fraction_downloaded)
|
||||
|
||||
@pyqtSlot(list)
|
||||
def on_task_completed(self, segments: List[Segment]):
|
||||
if self.current_task is not None:
|
||||
self.current_task.status = FileTranscriptionTask.Status.COMPLETED
|
||||
self.current_task.segments = segments
|
||||
self.current_task.completed_at = datetime.datetime.now()
|
||||
self.task_updated.emit(self.current_task)
|
||||
self.task_completed.emit(self.current_task, segments)
|
||||
|
||||
def stop(self):
|
||||
self.tasks_queue.put(None)
|
||||
|
|
|
@ -68,7 +68,12 @@ class FileTranscriber(QObject):
|
|||
output_format
|
||||
) in self.transcription_task.file_transcription_options.output_formats:
|
||||
default_path = get_output_file_path(
|
||||
task=self.transcription_task, output_format=output_format
|
||||
file_path=self.transcription_task.file_path,
|
||||
output_format=output_format,
|
||||
language=self.transcription_task.transcription_options.language,
|
||||
output_directory=self.transcription_task.output_directory,
|
||||
model=self.transcription_task.transcription_options.model,
|
||||
task=self.transcription_task.transcription_options.task,
|
||||
)
|
||||
|
||||
write_output(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import enum
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from random import randint
|
||||
from typing import List, Optional, Tuple, Set
|
||||
|
@ -174,7 +175,9 @@ class FileTranscriptionTask:
|
|||
transcription_options: TranscriptionOptions
|
||||
file_transcription_options: FileTranscriptionOptions
|
||||
model_path: str
|
||||
# deprecated: use uid
|
||||
id: int = field(default_factory=lambda: randint(0, 100_000_000))
|
||||
uid: uuid.UUID = field(default_factory=uuid.uuid4)
|
||||
segments: List[Segment] = field(default_factory=list)
|
||||
status: Optional[Status] = None
|
||||
fraction_completed = 0.0
|
||||
|
@ -188,38 +191,6 @@ class FileTranscriptionTask:
|
|||
url: Optional[str] = None
|
||||
fraction_downloaded: float = 0.0
|
||||
|
||||
def status_text(self) -> str:
    """Return the localized, human-readable label for this task's status.

    In-progress tasks include a percentage, completed tasks include the
    elapsed time when both timestamps are known, and failed tasks include
    the stored error. Any other status (including ``None``) yields ``""``.
    """
    statuses = FileTranscriptionTask.Status
    current = self.status

    if current == statuses.IN_PROGRESS:
        # A download fraction with no transcription progress yet is shown
        # as "Downloading" rather than "In Progress".
        still_downloading = (
            self.fraction_downloaded > 0 and self.fraction_completed == 0
        )
        if still_downloading:
            return f'{_("Downloading")} ({self.fraction_downloaded:.0%})'
        return f'{_("In Progress")} ({self.fraction_completed:.0%})'

    if current == statuses.COMPLETED:
        label = _("Completed")
        if self.started_at is None or self.completed_at is None:
            return label
        elapsed = self.completed_at - self.started_at
        label += f" ({self.format_timedelta(elapsed)})"
        return label

    if current == statuses.FAILED:
        return f'{_("Failed")} ({self.error})'

    if current == statuses.CANCELED:
        return _("Canceled")

    if current == statuses.QUEUED:
        return _("Queued")

    return ""
|
||||
|
||||
@staticmethod
|
||||
def format_timedelta(delta: datetime.timedelta):
|
||||
mm, ss = divmod(delta.seconds, 60)
|
||||
result = f"{ss}s"
|
||||
if mm == 0:
|
||||
return result
|
||||
hh, mm = divmod(mm, 60)
|
||||
result = f"{mm}m {result}"
|
||||
if hh == 0:
|
||||
return result
|
||||
return f"{hh}h {result}"
|
||||
|
||||
|
||||
class OutputFormat(enum.Enum):
|
||||
TXT = "txt"
|
||||
|
@ -236,11 +207,15 @@ Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)"
|
|||
|
||||
|
||||
def get_output_file_path(
|
||||
task: FileTranscriptionTask,
|
||||
file_path: str,
|
||||
task: Task,
|
||||
language: Optional[str],
|
||||
model: TranscriptionModel,
|
||||
output_format: OutputFormat,
|
||||
output_directory: str | None = None,
|
||||
export_file_name_template: str | None = None,
|
||||
):
|
||||
input_file_name = os.path.splitext(os.path.basename(task.file_path))[0]
|
||||
input_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
|
||||
|
||||
export_file_name_template = (
|
||||
|
@ -251,18 +226,18 @@ def get_output_file_path(
|
|||
|
||||
output_file_name = (
|
||||
export_file_name_template.replace("{{ input_file_name }}", input_file_name)
|
||||
.replace("{{ task }}", task.transcription_options.task.value)
|
||||
.replace("{{ language }}", task.transcription_options.language or "")
|
||||
.replace("{{ model_type }}", task.transcription_options.model.model_type.value)
|
||||
.replace("{{ task }}", task.value)
|
||||
.replace("{{ language }}", language or "")
|
||||
.replace("{{ model_type }}", model.model_type.value)
|
||||
.replace(
|
||||
"{{ model_size }}",
|
||||
task.transcription_options.model.whisper_model_size.value
|
||||
if task.transcription_options.model.whisper_model_size is not None
|
||||
model.whisper_model_size.value
|
||||
if model.whisper_model_size is not None
|
||||
else "",
|
||||
)
|
||||
.replace("{{ date_time }}", date_time_now)
|
||||
+ f".{output_format.value}"
|
||||
)
|
||||
|
||||
output_directory = task.output_directory or os.path.dirname(task.file_path)
|
||||
output_directory = output_directory or os.path.dirname(file_path)
|
||||
return os.path.join(output_directory, output_file_name)
|
||||
|
|
|
@ -3,6 +3,7 @@ import sys
|
|||
from PyQt6.QtWidgets import QApplication
|
||||
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
from buzz.settings.settings import APP_NAME
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask
|
||||
from buzz.widgets.main_window import MainWindow
|
||||
|
@ -11,7 +12,7 @@ from buzz.widgets.main_window import MainWindow
|
|||
class Application(QApplication):
|
||||
window: MainWindow
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, transcription_service: TranscriptionService) -> None:
|
||||
super().__init__(sys.argv)
|
||||
|
||||
self.setApplicationName(APP_NAME)
|
||||
|
@ -20,7 +21,7 @@ class Application(QApplication):
|
|||
if sys.platform == "darwin":
|
||||
self.setStyle("Fusion")
|
||||
|
||||
self.window = MainWindow()
|
||||
self.window = MainWindow(transcription_service)
|
||||
self.window.show()
|
||||
|
||||
def add_task(self, task: FileTranscriptionTask):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Dict, Tuple, List, Optional
|
||||
import logging
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (
|
||||
|
@ -13,7 +14,8 @@ from PyQt6.QtWidgets import (
|
|||
QFileDialog,
|
||||
)
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
from buzz.file_transcriber_queue_worker import FileTranscriberQueueWorker
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import APP_NAME, Settings
|
||||
|
@ -24,6 +26,7 @@ from buzz.transcriber.transcriber import (
|
|||
TranscriptionOptions,
|
||||
FileTranscriptionOptions,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
Segment,
|
||||
)
|
||||
from buzz.widgets.icon import BUZZ_ICON_PATH
|
||||
from buzz.widgets.import_url_dialog import ImportURLDialog
|
||||
|
@ -34,7 +37,9 @@ from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidg
|
|||
from buzz.widgets.transcription_task_folder_watcher import (
|
||||
TranscriptionTaskFolderWatcher,
|
||||
)
|
||||
from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
|
||||
from buzz.widgets.transcription_tasks_table_widget import (
|
||||
TranscriptionTasksTableWidget,
|
||||
)
|
||||
from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
|
||||
TranscriptionViewerWidget,
|
||||
)
|
||||
|
@ -42,9 +47,8 @@ from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
|
|||
|
||||
class MainWindow(QMainWindow):
|
||||
table_widget: TranscriptionTasksTableWidget
|
||||
tasks: Dict[int, "FileTranscriptionTask"]
|
||||
|
||||
def __init__(self, tasks_cache=TasksCache()):
|
||||
def __init__(self, transcription_service: TranscriptionService):
|
||||
super().__init__(flags=Qt.WindowType.Window)
|
||||
|
||||
self.setWindowTitle(APP_NAME)
|
||||
|
@ -53,14 +57,12 @@ class MainWindow(QMainWindow):
|
|||
|
||||
self.setAcceptDrops(True)
|
||||
|
||||
self.tasks_cache = tasks_cache
|
||||
|
||||
self.settings = Settings()
|
||||
|
||||
self.shortcut_settings = ShortcutSettings(settings=self.settings)
|
||||
self.shortcuts = self.shortcut_settings.load()
|
||||
|
||||
self.tasks = {}
|
||||
self.transcription_service = transcription_service
|
||||
|
||||
self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self)
|
||||
self.toolbar.new_transcription_action_triggered.connect(
|
||||
|
@ -100,7 +102,9 @@ class MainWindow(QMainWindow):
|
|||
self.table_widget = TranscriptionTasksTableWidget(self)
|
||||
self.table_widget.doubleClicked.connect(self.on_table_double_clicked)
|
||||
self.table_widget.return_clicked.connect(self.open_transcript_viewer)
|
||||
self.table_widget.itemSelectionChanged.connect(self.on_table_selection_changed)
|
||||
self.table_widget.selectionModel().selectionChanged.connect(
|
||||
self.on_table_selection_changed
|
||||
)
|
||||
|
||||
self.setCentralWidget(self.table_widget)
|
||||
|
||||
|
@ -110,19 +114,24 @@ class MainWindow(QMainWindow):
|
|||
self.transcriber_worker = FileTranscriberQueueWorker()
|
||||
self.transcriber_worker.moveToThread(self.transcriber_thread)
|
||||
|
||||
self.transcriber_worker.task_updated.connect(self.update_task_table_row)
|
||||
self.transcriber_worker.task_started.connect(self.on_task_started)
|
||||
self.transcriber_worker.task_progress.connect(self.on_task_progress)
|
||||
self.transcriber_worker.task_download_progress.connect(
|
||||
self.on_task_download_progress
|
||||
)
|
||||
self.transcriber_worker.task_error.connect(self.on_task_error)
|
||||
self.transcriber_worker.task_completed.connect(self.on_task_completed)
|
||||
|
||||
self.transcriber_worker.completed.connect(self.transcriber_thread.quit)
|
||||
|
||||
self.transcriber_thread.started.connect(self.transcriber_worker.run)
|
||||
|
||||
self.transcriber_thread.start()
|
||||
|
||||
self.load_tasks_from_cache()
|
||||
|
||||
self.load_geometry()
|
||||
|
||||
self.folder_watcher = TranscriptionTaskFolderWatcher(
|
||||
tasks=self.tasks,
|
||||
tasks={},
|
||||
preferences=self.preferences.folder_watch,
|
||||
)
|
||||
self.folder_watcher.task_found.connect(self.add_task)
|
||||
|
@ -181,21 +190,6 @@ class MainWindow(QMainWindow):
|
|||
)
|
||||
self.add_task(task)
|
||||
|
||||
def upsert_task_in_table(self, task: FileTranscriptionTask):
|
||||
self.table_widget.upsert_task(task)
|
||||
self.tasks[task.id] = task
|
||||
|
||||
def update_task_table_row(self, task: FileTranscriptionTask):
|
||||
self.upsert_task_in_table(task=task)
|
||||
self.on_tasks_changed()
|
||||
|
||||
@staticmethod
|
||||
def task_completed_or_errored(task: FileTranscriptionTask):
|
||||
return (
|
||||
task.status == FileTranscriptionTask.Status.COMPLETED
|
||||
or task.status == FileTranscriptionTask.Status.FAILED
|
||||
)
|
||||
|
||||
def on_clear_history_action_triggered(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
if len(selected_rows) == 0:
|
||||
|
@ -210,25 +204,18 @@ class MainWindow(QMainWindow):
|
|||
),
|
||||
)
|
||||
if reply == QMessageBox.StandardButton.Yes:
|
||||
task_ids = [
|
||||
TranscriptionTasksTableWidget.find_task_id(selected_row)
|
||||
for selected_row in selected_rows
|
||||
]
|
||||
for task_id in task_ids:
|
||||
self.table_widget.clear_task(task_id)
|
||||
self.tasks.pop(task_id)
|
||||
self.on_tasks_changed()
|
||||
self.table_widget.delete_transcriptions(selected_rows)
|
||||
|
||||
def on_stop_transcription_action_triggered(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
for selected_row in selected_rows:
|
||||
task_id = TranscriptionTasksTableWidget.find_task_id(selected_row)
|
||||
task = self.tasks[task_id]
|
||||
|
||||
task.status = FileTranscriptionTask.Status.CANCELED
|
||||
self.on_tasks_changed()
|
||||
self.transcriber_worker.cancel_task(task_id)
|
||||
self.table_widget.upsert_task(task)
|
||||
selected_transcriptions = self.table_widget.selected_transcriptions()
|
||||
for transcription in selected_transcriptions:
|
||||
transcription_id = transcription.id_as_uuid
|
||||
self.transcriber_worker.cancel_task(transcription_id)
|
||||
self.transcription_service.update_transcription_as_canceled(
|
||||
transcription_id
|
||||
)
|
||||
self.table_widget.refresh_row(transcription_id)
|
||||
self.on_table_selection_changed()
|
||||
|
||||
def on_new_transcription_action_triggered(self):
|
||||
(file_paths, __) = QFileDialog.getOpenFileNames(
|
||||
|
@ -266,8 +253,8 @@ class MainWindow(QMainWindow):
|
|||
def open_transcript_viewer(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
for selected_row in selected_rows:
|
||||
task_id = TranscriptionTasksTableWidget.find_task_id(selected_row)
|
||||
self.open_transcription_viewer(task_id)
|
||||
transcription = self.table_widget.transcription(selected_row)
|
||||
self.open_transcription_viewer(transcription)
|
||||
|
||||
def on_table_selection_changed(self):
|
||||
self.toolbar.set_open_transcript_action_enabled(
|
||||
|
@ -281,7 +268,20 @@ class MainWindow(QMainWindow):
|
|||
)
|
||||
|
||||
def should_enable_open_transcript_action(self):
|
||||
return self.selected_tasks_have_status([FileTranscriptionTask.Status.COMPLETED])
|
||||
selected_transcriptions = self.table_widget.selected_transcriptions()
|
||||
if len(selected_transcriptions) == 0:
|
||||
return False
|
||||
return all(
|
||||
MainWindow.can_open_transcript(transcription)
|
||||
for transcription in selected_transcriptions
|
||||
)
|
||||
|
||||
@staticmethod
def can_open_transcript(transcription: Transcription) -> bool:
    """Return ``True`` only when the transcription's status is COMPLETED.

    The stored ``transcription.status`` value is converted to a
    ``FileTranscriptionTask.Status`` member before the comparison.
    """
    status = FileTranscriptionTask.Status(transcription.status)
    return status == FileTranscriptionTask.Status.COMPLETED
|
||||
|
||||
def should_enable_stop_transcription_action(self):
|
||||
return self.selected_tasks_have_status(
|
||||
|
@ -301,63 +301,56 @@ class MainWindow(QMainWindow):
|
|||
)
|
||||
|
||||
def selected_tasks_have_status(self, statuses: List[FileTranscriptionTask.Status]):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
if len(selected_rows) == 0:
|
||||
transcriptions = self.table_widget.selected_transcriptions()
|
||||
if len(transcriptions) == 0:
|
||||
return False
|
||||
|
||||
return all(
|
||||
[
|
||||
self.tasks[
|
||||
TranscriptionTasksTableWidget.find_task_id(selected_row)
|
||||
].status
|
||||
in statuses
|
||||
for selected_row in selected_rows
|
||||
transcription.status_as_status in statuses
|
||||
for transcription in transcriptions
|
||||
]
|
||||
)
|
||||
|
||||
def on_table_double_clicked(self, index: QModelIndex):
|
||||
task_id = TranscriptionTasksTableWidget.find_task_id(index)
|
||||
self.open_transcription_viewer(task_id)
|
||||
|
||||
def open_transcription_viewer(self, task_id: int):
|
||||
task = self.tasks[task_id]
|
||||
if task.status != FileTranscriptionTask.Status.COMPLETED:
|
||||
transcription = self.table_widget.transcription(index)
|
||||
if not MainWindow.can_open_transcript(transcription):
|
||||
return
|
||||
self.open_transcription_viewer(transcription)
|
||||
|
||||
def open_transcription_viewer(self, transcription: Transcription):
|
||||
transcription_viewer_widget = TranscriptionViewerWidget(
|
||||
transcription_task=task, parent=self, flags=Qt.WindowType.Window
|
||||
transcription=transcription, parent=self, flags=Qt.WindowType.Window
|
||||
)
|
||||
transcription_viewer_widget.task_changed.connect(self.on_tasks_changed)
|
||||
transcription_viewer_widget.show()
|
||||
|
||||
def add_task(self, task: FileTranscriptionTask):
|
||||
self.transcription_service.create_transcription(task)
|
||||
self.table_widget.refresh_all()
|
||||
self.transcriber_worker.add_task(task)
|
||||
|
||||
def load_tasks_from_cache(self):
|
||||
tasks = self.tasks_cache.load()
|
||||
for task in tasks:
|
||||
if (
|
||||
task.status == FileTranscriptionTask.Status.QUEUED
|
||||
or task.status == FileTranscriptionTask.Status.IN_PROGRESS
|
||||
):
|
||||
task.status = None
|
||||
self.add_task(task)
|
||||
else:
|
||||
self.upsert_task_in_table(task=task)
|
||||
def on_task_started(self, task: FileTranscriptionTask):
|
||||
self.transcription_service.update_transcription_as_started(task.uid)
|
||||
self.table_widget.refresh_row(task.uid)
|
||||
|
||||
def save_tasks_to_cache(self):
|
||||
self.tasks_cache.save(list(self.tasks.values()))
|
||||
def on_task_progress(self, task: FileTranscriptionTask, progress: float):
|
||||
self.transcription_service.update_transcription_progress(task.uid, progress)
|
||||
self.table_widget.refresh_row(task.uid)
|
||||
|
||||
def on_tasks_changed(self):
|
||||
self.toolbar.set_open_transcript_action_enabled(
|
||||
self.should_enable_open_transcript_action()
|
||||
)
|
||||
self.toolbar.set_stop_transcription_action_enabled(
|
||||
self.should_enable_stop_transcription_action()
|
||||
)
|
||||
self.toolbar.set_clear_history_action_enabled(
|
||||
self.should_enable_clear_history_action()
|
||||
)
|
||||
self.save_tasks_to_cache()
|
||||
def on_task_download_progress(
|
||||
self, task: FileTranscriptionTask, fraction_downloaded: float
|
||||
):
|
||||
# TODO: Save download progress in the database
|
||||
pass
|
||||
|
||||
def on_task_completed(self, task: FileTranscriptionTask, segments: List[Segment]):
|
||||
self.transcription_service.update_transcription_as_completed(task.uid, segments)
|
||||
self.table_widget.refresh_row(task.uid)
|
||||
|
||||
def on_task_error(self, task: FileTranscriptionTask, error: str):
    """Persist a task failure and refresh the task's table row.

    Args:
        task: The transcription task that failed.
        error: The error message reported for the task.
    """
    # Log which task failed and why instead of an uninformative placeholder.
    logging.debug("Transcription task %s failed: %s", task.uid, error)
    self.transcription_service.update_transcription_as_failed(task.uid, error)
    self.table_widget.refresh_row(task.uid)
|
||||
|
||||
def on_shortcuts_changed(self, shortcuts: dict):
|
||||
self.shortcuts = shortcuts
|
||||
|
@ -374,7 +367,6 @@ class MainWindow(QMainWindow):
|
|||
self.transcriber_worker.stop()
|
||||
self.transcriber_thread.quit()
|
||||
self.transcriber_thread.wait()
|
||||
self.save_tasks_to_cache()
|
||||
self.shortcut_settings.save(shortcuts=self.shortcuts)
|
||||
super().closeEvent(event)
|
||||
|
||||
|
|
15
buzz/widgets/record_delegate.py
Normal file
15
buzz/widgets/record_delegate.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from typing import Callable
|
||||
|
||||
from PyQt6.QtSql import QSqlRecord, QSqlTableModel
|
||||
from PyQt6.QtWidgets import QStyledItemDelegate
|
||||
|
||||
|
||||
class RecordDelegate(QStyledItemDelegate):
    """Item delegate whose displayed text is computed from the row's record.

    The supplied ``text_getter`` receives the whole ``QSqlRecord`` for the
    row being painted, so a cell's text can depend on any field of the row.
    """

    def __init__(self, text_getter: Callable[[QSqlRecord], str]):
        """Store the callable that maps a row's record to its display text."""
        super().__init__()
        self.callback = text_getter

    def initStyleOption(self, option, index):
        """Initialize ``option`` as usual, then override its text."""
        super().initStyleOption(option, index)
        table_model: QSqlTableModel = index.model()
        row_record = table_model.record(index.row())
        option.text = self.callback(row_record)
|
25
buzz/widgets/transcription_record.py
Normal file
25
buzz/widgets/transcription_record.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from uuid import UUID
|
||||
|
||||
from PyQt6.QtSql import QSqlRecord
|
||||
|
||||
from buzz.model_loader import TranscriptionModel, ModelType, WhisperModelSize
|
||||
from buzz.transcriber.transcriber import Task
|
||||
|
||||
|
||||
class TranscriptionRecord:
    """Typed accessors for fields of a transcription ``QSqlRecord``."""

    @staticmethod
    def id(record: QSqlRecord) -> UUID:
        """Return the record's ``id`` field parsed from its hex string."""
        hex_id = record.value("id")
        return UUID(hex=hex_id)

    @staticmethod
    def model(record: QSqlRecord) -> TranscriptionModel:
        """Build a ``TranscriptionModel`` from the record's model fields.

        ``whisper_model_size`` is ``None`` when the stored value is falsy.
        """
        model_type = ModelType(record.value("model_type"))
        if record.value("whisper_model_size"):
            model_size = WhisperModelSize(record.value("whisper_model_size"))
        else:
            model_size = None
        return TranscriptionModel(
            model_type=model_type,
            whisper_model_size=model_size,
        )

    @staticmethod
    def task(record: QSqlRecord) -> Task:
        """Return the record's ``task`` field as a ``Task`` member."""
        task_value = record.value("task")
        return Task(task_value)
|
|
@ -15,6 +15,7 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
|
|||
preferences: FolderWatchPreferences
|
||||
task_found = pyqtSignal(FileTranscriptionTask)
|
||||
|
||||
# TODO: query db instead of passing tasks
|
||||
def __init__(
|
||||
self,
|
||||
tasks: Dict[int, FileTranscriptionTask],
|
||||
|
|
|
@ -1,147 +1,203 @@
|
|||
import enum
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import auto
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtCore import pyqtSignal, QModelIndex
|
||||
from PyQt6.QtSql import QSqlTableModel, QSqlRecord
|
||||
from PyQt6.QtWidgets import (
|
||||
QTableWidget,
|
||||
QWidget,
|
||||
QAbstractItemView,
|
||||
QTableWidgetItem,
|
||||
QMenu,
|
||||
QTableView,
|
||||
QAbstractItemView,
|
||||
QStyledItemDelegate,
|
||||
)
|
||||
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, humanize_language
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask
|
||||
from buzz.widgets.record_delegate import RecordDelegate
|
||||
from buzz.widgets.transcription_record import TranscriptionRecord
|
||||
|
||||
|
||||
class Column(enum.Enum):
    """Column indexes used with the transcriptions table model.

    The values are passed to the view/model as column numbers (sorting,
    hiding, widths, delegates).
    NOTE(review): presumably these mirror the column order of the
    ``transcription`` table - confirm against the schema before reordering.
    """

    ID = 0
    ERROR_MESSAGE = 1
    EXPORT_FORMATS = 2
    FILE = 3
    OUTPUT_FOLDER = 4
    PROGRESS = 5
    LANGUAGE = 6
    MODEL_TYPE = 7
    SOURCE = 8
    STATUS = 9
    TASK = 10
    TIME_ENDED = 11
    TIME_QUEUED = 12
    TIME_STARTED = 13
    URL = 14
    WHISPER_MODEL_SIZE = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableColDef:
|
||||
class ColDef:
|
||||
id: str
|
||||
header: str
|
||||
column_index: int
|
||||
value_getter: Callable[[FileTranscriptionTask], str]
|
||||
column: Column
|
||||
width: Optional[int] = None
|
||||
hidden: bool = False
|
||||
delegate: Optional[QStyledItemDelegate] = None
|
||||
hidden_toggleable: bool = True
|
||||
|
||||
|
||||
class TranscriptionTasksTableWidget(QTableWidget):
|
||||
class Column(enum.Enum):
|
||||
TASK_ID = 0
|
||||
FILE_NAME = auto()
|
||||
MODEL = auto()
|
||||
TASK = auto()
|
||||
STATUS = auto()
|
||||
DATE_ADDED = auto()
|
||||
DATE_COMPLETED = auto()
|
||||
def format_record_status_text(record: QSqlRecord) -> str:
    """Build the status text displayed for a transcription record.

    In-progress rows include the ``progress`` value as a percentage,
    completed rows include the elapsed time when both ``time_started`` and
    ``time_ended`` are non-empty, and failed rows include ``error_message``.
    """
    statuses = FileTranscriptionTask.Status
    current = statuses(record.value("status"))

    if current == statuses.IN_PROGRESS:
        return f'{_("In Progress")} ({record.value("progress"):.0%})'

    if current == statuses.COMPLETED:
        label = _("Completed")
        started_at = record.value("time_started")
        completed_at = record.value("time_ended")
        if started_at == "" or completed_at == "":
            return label
        # Timestamps are parsed from ISO-format strings to compute the
        # elapsed duration.
        elapsed = datetime.fromisoformat(completed_at) - datetime.fromisoformat(
            started_at
        )
        label += f" ({TranscriptionTasksTableWidget.format_timedelta(elapsed)})"
        return label

    if current == statuses.FAILED:
        return f'{_("Failed")} ({record.value("error_message")})'

    if current == statuses.CANCELED:
        return _("Canceled")

    if current == statuses.QUEUED:
        return _("Queued")

    return ""
|
||||
|
||||
|
||||
column_definitions = [
|
||||
ColDef(
|
||||
id="file_name",
|
||||
header="File Name / URL",
|
||||
column=Column.FILE,
|
||||
width=400,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: record.value("url")
|
||||
if record.value("url") != ""
|
||||
else os.path.basename(record.value("file"))
|
||||
),
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
ColDef(
|
||||
id="model",
|
||||
header="Model",
|
||||
column=Column.MODEL_TYPE,
|
||||
width=180,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: str(TranscriptionRecord.model(record))
|
||||
),
|
||||
),
|
||||
ColDef(
|
||||
id="task",
|
||||
header="Task",
|
||||
column=Column.SOURCE,
|
||||
width=120,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: record.value("task").capitalize()
|
||||
),
|
||||
),
|
||||
ColDef(
|
||||
id="status",
|
||||
header="Status",
|
||||
column=Column.STATUS,
|
||||
width=180,
|
||||
delegate=RecordDelegate(text_getter=format_record_status_text),
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
ColDef(
|
||||
id="date_added",
|
||||
header="Date Added",
|
||||
column=Column.TIME_QUEUED,
|
||||
width=180,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: datetime.fromisoformat(
|
||||
record.value("time_queued")
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
),
|
||||
),
|
||||
ColDef(
|
||||
id="date_completed",
|
||||
header="Date Completed",
|
||||
column=Column.TIME_ENDED,
|
||||
width=180,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: datetime.fromisoformat(
|
||||
record.value("time_ended")
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
if record.value("time_ended") != ""
|
||||
else ""
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TranscriptionTasksTableWidget(QTableView):
|
||||
return_clicked = pyqtSignal()
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.setRowCount(0)
|
||||
self.setAlternatingRowColors(True)
|
||||
self._model = QSqlTableModel()
|
||||
self._model.setTable("transcription")
|
||||
self._model.setEditStrategy(QSqlTableModel.EditStrategy.OnManualSubmit)
|
||||
self._model.setSort(Column.TIME_QUEUED.value, Qt.SortOrder.DescendingOrder)
|
||||
|
||||
self.setModel(self._model)
|
||||
|
||||
for i in range(self.model().columnCount()):
|
||||
self.hideColumn(i)
|
||||
|
||||
self.settings = Settings()
|
||||
|
||||
self.column_definitions = [
|
||||
TableColDef(
|
||||
id="id",
|
||||
header=_("ID"),
|
||||
column_index=self.Column.TASK_ID.value,
|
||||
value_getter=lambda task: str(task.id),
|
||||
width=0,
|
||||
hidden=True,
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
TableColDef(
|
||||
id="file_name",
|
||||
header=_("File Name/URL"),
|
||||
column_index=self.Column.FILE_NAME.value,
|
||||
value_getter=lambda task: task.url
|
||||
if task.url is not None
|
||||
else os.path.basename(task.file_path),
|
||||
width=300,
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
TableColDef(
|
||||
id="model",
|
||||
header=_("Model"),
|
||||
column_index=self.Column.MODEL.value,
|
||||
value_getter=lambda task: str(task.transcription_options.model),
|
||||
width=180,
|
||||
hidden=True,
|
||||
),
|
||||
TableColDef(
|
||||
id="task",
|
||||
header=_("Task"),
|
||||
column_index=self.Column.TASK.value,
|
||||
value_getter=lambda task: self.get_task_label(task),
|
||||
width=120,
|
||||
hidden=True,
|
||||
),
|
||||
TableColDef(
|
||||
id="status",
|
||||
header=_("Status"),
|
||||
column_index=self.Column.STATUS.value,
|
||||
value_getter=lambda task: task.status_text(),
|
||||
width=180,
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
TableColDef(
|
||||
id="date_added",
|
||||
header=_("Date Added"),
|
||||
column_index=self.Column.DATE_ADDED.value,
|
||||
value_getter=lambda task: task.queued_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if task.queued_at is not None
|
||||
else "",
|
||||
width=180,
|
||||
hidden=True,
|
||||
),
|
||||
TableColDef(
|
||||
id="date_completed",
|
||||
header=_("Date Completed"),
|
||||
column_index=self.Column.DATE_COMPLETED.value,
|
||||
value_getter=lambda task: task.completed_at.strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
if task.completed_at is not None
|
||||
else "",
|
||||
width=180,
|
||||
hidden=True,
|
||||
),
|
||||
]
|
||||
|
||||
self.setColumnCount(len(self.column_definitions))
|
||||
self.verticalHeader().hide()
|
||||
self.setHorizontalHeaderLabels(
|
||||
[definition.header for definition in self.column_definitions]
|
||||
self.settings.begin_group(
|
||||
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
|
||||
)
|
||||
for definition in self.column_definitions:
|
||||
for definition in column_definitions:
|
||||
self.model().setHeaderData(
|
||||
definition.column.value,
|
||||
Qt.Orientation.Horizontal,
|
||||
definition.header,
|
||||
)
|
||||
|
||||
visible = self.settings.settings.value(definition.id, True)
|
||||
self.setColumnHidden(definition.column.value, not visible)
|
||||
if definition.width is not None:
|
||||
self.setColumnWidth(definition.column_index, definition.width)
|
||||
self.load_column_visibility()
|
||||
|
||||
self.horizontalHeader().setMinimumSectionSize(180)
|
||||
self.setColumnWidth(definition.column.value, definition.width)
|
||||
if definition.delegate is not None:
|
||||
self.setItemDelegateForColumn(
|
||||
definition.column.value, definition.delegate
|
||||
)
|
||||
self.settings.end_group()
|
||||
|
||||
self.model().select()
|
||||
self.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
|
||||
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
self.verticalHeader().hide()
|
||||
self.setAlternatingRowColors(True)
|
||||
|
||||
def contextMenuEvent(self, event):
|
||||
menu = QMenu(self)
|
||||
for definition in self.column_definitions:
|
||||
for definition in column_definitions:
|
||||
if not definition.hidden_toggleable:
|
||||
continue
|
||||
action = menu.addAction(definition.header)
|
||||
action.setCheckable(True)
|
||||
action.setChecked(not self.isColumnHidden(definition.column_index))
|
||||
action.setChecked(not self.isColumnHidden(definition.column.value))
|
||||
action.toggled.connect(
|
||||
lambda checked,
|
||||
column_index=definition.column_index: self.on_column_checked(
|
||||
column_index=definition.column.value: self.on_column_checked(
|
||||
column_index, checked
|
||||
)
|
||||
)
|
||||
|
@ -155,66 +211,47 @@ class TranscriptionTasksTableWidget(QTableWidget):
|
|||
self.settings.begin_group(
|
||||
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
|
||||
)
|
||||
for definition in self.column_definitions:
|
||||
for definition in column_definitions:
|
||||
self.settings.settings.setValue(
|
||||
definition.id, not self.isColumnHidden(definition.column_index)
|
||||
definition.id, not self.isColumnHidden(definition.column.value)
|
||||
)
|
||||
self.settings.end_group()
|
||||
|
||||
def load_column_visibility(self):
|
||||
self.settings.begin_group(
|
||||
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
|
||||
)
|
||||
for definition in self.column_definitions:
|
||||
visible = self.settings.settings.value(definition.id, not definition.hidden)
|
||||
self.setColumnHidden(definition.column_index, not visible)
|
||||
self.settings.end_group()
|
||||
|
||||
def upsert_task(self, task: FileTranscriptionTask):
    """Insert a row for ``task``, or refresh the text of its existing row.

    The row is located through ``task_row_index(task.id)``. Every cell's text
    comes from the matching column definition's ``value_getter``.
    """
    existing_row = self.task_row_index(task.id)

    if existing_row is not None:
        # Task already shown: only the cell texts need updating.
        for col_def in self.column_definitions:
            cell = self.item(existing_row, col_def.column_index)
            cell.setText(col_def.value_getter(task))
        return

    # New task: append a row at the bottom and fill it with read-only cells.
    self.insertRow(self.rowCount())
    new_row = self.rowCount() - 1
    for col_def in self.column_definitions:
        cell = QTableWidgetItem(col_def.value_getter(task))
        cell.setFlags(cell.flags() & ~Qt.ItemFlag.ItemIsEditable)
        self.setItem(new_row, col_def.column_index, cell)
|
||||
|
||||
@staticmethod
def get_task_label(task: FileTranscriptionTask) -> str:
    """Return the capitalized task name, followed by the language in parentheses when one is set."""
    options = task.transcription_options
    label = options.task.value.capitalize()
    if options.language is None:
        return label
    return label + f" ({humanize_language(options.language)})"
|
||||
|
||||
def clear_task(self, task_id: int):
    """Remove the table row for ``task_id``; do nothing when no such row exists."""
    row = self.task_row_index(task_id)
    if row is None:
        return
    self.removeRow(row)
|
||||
|
||||
def task_row_index(self, task_id: int) -> int | None:
    """Return the row of the first TASK_ID cell whose text is exactly ``str(task_id)``.

    Returns ``None`` when no cell in the TASK_ID column matches.
    """
    # findItems() searches every column, so matches outside the TASK_ID
    # column have to be ignored.
    for cell in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly):
        if cell.column() == self.Column.TASK_ID.value:
            return cell.row()
    return None
|
||||
|
||||
@staticmethod
def find_task_id(index: QModelIndex):
    """Return the task id found in the TASK_ID column of ``index``'s row.

    The sibling cell's data is converted with ``int``; ``None`` is returned
    when that cell holds no data.
    """
    task_id_column = TranscriptionTasksTableWidget.Column.TASK_ID.value
    raw_task_id = index.siblingAtColumn(task_id_column).data()
    if raw_task_id is None:
        return None
    return int(raw_task_id)
|
||||
|
||||
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
    """Emit ``return_clicked`` when Return is pressed, then run the base handler."""
    pressed_return = event.key() == Qt.Key.Key_Return
    if pressed_return:
        self.return_clicked.emit()
    # The event is always forwarded, whether or not the signal was emitted.
    super().keyPressEvent(event)
|
||||
|
||||
def selected_transcriptions(self) -> List[Transcription]:
    """Return one ``Transcription`` per row currently selected in the view."""
    transcriptions = []
    for row_index in self.selectionModel().selectedRows():
        transcriptions.append(self.transcription(row_index))
    return transcriptions
|
||||
|
||||
def delete_transcriptions(self, rows: List[QModelIndex]):
    """Remove the model row of every index in ``rows``, then submit all changes."""
    model = self.model()
    for index in rows:
        model.removeRow(index.row())
    # A single submit after the loop, not one per removed row.
    model.submitAll()
|
||||
|
||||
def transcription(self, index: QModelIndex) -> Transcription:
    """Build a ``Transcription`` from the model record at ``index``'s row."""
    record = self.model().record(index.row())
    return Transcription.from_record(record)
|
||||
|
||||
def refresh_all(self):
    """Ask the underlying model to run ``select()`` again."""
    model = self.model()
    model.select()
|
||||
|
||||
def refresh_row(self, id: UUID):
    """Call ``selectRow`` on the first model row whose "id" field equals ``str(id)``.

    Rows are scanned in order; nothing happens when no row matches.
    """
    model = self.model()
    wanted_id = str(id)
    for row in range(model.rowCount()):
        if model.record(row).value("id") != wanted_id:
            continue
        model.selectRow(row)
        return
|
||||
|
||||
@staticmethod
|
||||
def format_timedelta(delta: timedelta):
|
||||
mm, ss = divmod(delta.seconds, 60)
|
||||
result = f"{ss}s"
|
||||
if mm == 0:
|
||||
return result
|
||||
hh, mm = divmod(mm, 60)
|
||||
result = f"{mm}m {result}"
|
||||
if hh == 0:
|
||||
return result
|
||||
return f"{hh}h {result}"
|
||||
|
|
|
@ -1,20 +1,18 @@
|
|||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtGui import QAction
|
||||
from PyQt6.QtWidgets import QPushButton, QWidget, QMenu, QFileDialog
|
||||
from PyQt6.QtWidgets import QPushButton, QWidget, QMenu
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.transcriber.file_transcriber import write_output
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
OutputFormat,
|
||||
get_output_file_path,
|
||||
)
|
||||
from buzz.widgets.icon import FileDownloadIcon
|
||||
|
||||
|
||||
class ExportTranscriptionButton(QPushButton):
|
||||
def __init__(self, transcription_task: FileTranscriptionTask, parent: QWidget):
|
||||
on_export_triggered = pyqtSignal(OutputFormat)
|
||||
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(parent)
|
||||
self.transcription_task = transcription_task
|
||||
|
||||
export_button_menu = QMenu()
|
||||
actions = [
|
||||
|
@ -29,23 +27,4 @@ class ExportTranscriptionButton(QPushButton):
|
|||
|
||||
def on_menu_triggered(self, action: QAction):
|
||||
output_format = OutputFormat[action.text()]
|
||||
|
||||
default_path = get_output_file_path(
|
||||
task=self.transcription_task, output_format=output_format
|
||||
)
|
||||
|
||||
(output_file_path, nil) = QFileDialog.getSaveFileName(
|
||||
self,
|
||||
_("Save File"),
|
||||
default_path,
|
||||
_("Text files") + f" (*.{output_format.value})",
|
||||
)
|
||||
|
||||
if output_file_path == "":
|
||||
return
|
||||
|
||||
write_output(
|
||||
path=output_file_path,
|
||||
segments=self.transcription_task.segments,
|
||||
output_format=output_format,
|
||||
)
|
||||
self.on_export_triggered.emit(output_format)
|
||||
|
|
|
@ -1,72 +1,105 @@
|
|||
import enum
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6.QtCore import Qt, pyqtSignal
|
||||
from PyQt6.QtWidgets import QTableWidget, QWidget, QHeaderView, QTableWidgetItem
|
||||
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex, QItemSelection
|
||||
from PyQt6.QtSql import QSqlTableModel, QSqlRecord
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QTableView,
|
||||
QStyledItemDelegate,
|
||||
QAbstractItemView,
|
||||
)
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.transcriber.file_transcriber import to_timestamp
|
||||
from buzz.transcriber.transcriber import Segment
|
||||
|
||||
|
||||
class TranscriptionSegmentsEditorWidget(QTableWidget):
|
||||
segment_text_changed = pyqtSignal(tuple)
|
||||
segment_index_selected = pyqtSignal(int)
|
||||
class Column(enum.Enum):
    """Integer column positions used to address the segment model's fields.

    NOTE(review): the order presumably mirrors the column order of the
    ``transcription_segment`` table - confirm against the schema.
    """

    ID = 0
    END = 1
    START = 2
    TEXT = 3
    TRANSCRIPTION_ID = 4
|
||||
|
||||
class Column(enum.Enum):
    """Integer column positions of the start, end and text cells."""

    START = 0
    END = 1
    TEXT = 2
|
||||
|
||||
def __init__(self, segments: List[Segment], parent: Optional[QWidget]):
|
||||
@dataclass
class ColDef:
    # Describes one column of the segments table.

    # Short string key of the column.
    id: str
    # Text set as the horizontal header for the column.
    header: str
    # Which model column this definition applies to.
    column: Column
    # Optional delegate installed on the column; None leaves the default.
    delegate: Optional[QStyledItemDelegate] = None
|
||||
|
||||
|
||||
class TimeStampDelegate(QStyledItemDelegate):
    """Item delegate that displays a cell's value as formatted by ``to_timestamp``."""

    def displayText(self, value, locale):
        # ``locale`` belongs to the delegate signature but is not used here.
        formatted = to_timestamp(value)
        return formatted
|
||||
|
||||
|
||||
class TranscriptionSegmentModel(QSqlTableModel):
    """Table model over ``transcription_segment`` rows of one transcription.

    The model is filtered by ``transcription_id``, uses the ``OnFieldChange``
    edit strategy, and reports the START and END columns as non-editable.
    """

    def __init__(self, transcription_id: UUID):
        super().__init__()
        self.setTable("transcription_segment")
        self.setEditStrategy(QSqlTableModel.EditStrategy.OnFieldChange)
        self.setFilter(f"transcription_id = '{transcription_id}'")

    def flags(self, index: QModelIndex):
        """Return the base flags, minus ``ItemIsEditable`` for START and END cells."""
        base_flags = super().flags(index)
        read_only_columns = (Column.START.value, Column.END.value)
        if index.column() not in read_only_columns:
            return base_flags
        return base_flags & ~Qt.ItemFlag.ItemIsEditable
|
||||
|
||||
|
||||
class TranscriptionSegmentsEditorWidget(QTableView):
|
||||
segment_selected = pyqtSignal(QSqlRecord)
|
||||
|
||||
def __init__(self, transcription_id: UUID, parent: Optional[QWidget]):
|
||||
super().__init__(parent)
|
||||
|
||||
self.segments = segments
|
||||
model = TranscriptionSegmentModel(transcription_id=transcription_id)
|
||||
self.setModel(model)
|
||||
|
||||
timestamp_delegate = TimeStampDelegate()
|
||||
|
||||
self.column_definitions: list[ColDef] = [
|
||||
ColDef("start", _("Start"), Column.START, delegate=timestamp_delegate),
|
||||
ColDef("end", _("End"), Column.END, delegate=timestamp_delegate),
|
||||
ColDef("text", _("Text"), Column.TEXT),
|
||||
]
|
||||
|
||||
for i in range(model.columnCount()):
|
||||
self.hideColumn(i)
|
||||
|
||||
for definition in self.column_definitions:
|
||||
model.setHeaderData(
|
||||
definition.column.value,
|
||||
Qt.Orientation.Horizontal,
|
||||
definition.header,
|
||||
)
|
||||
self.showColumn(definition.column.value)
|
||||
if definition.delegate is not None:
|
||||
self.setItemDelegateForColumn(
|
||||
definition.column.value, definition.delegate
|
||||
)
|
||||
|
||||
self.setAlternatingRowColors(True)
|
||||
|
||||
self.setColumnCount(3)
|
||||
|
||||
self.verticalHeader().hide()
|
||||
self.setHorizontalHeaderLabels([_("Start"), _("End"), _("Text")])
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
2, QHeaderView.ResizeMode.ResizeToContents
|
||||
)
|
||||
self.setSelectionMode(QTableWidget.SelectionMode.SingleSelection)
|
||||
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
self.setSelectionMode(QTableView.SelectionMode.SingleSelection)
|
||||
self.selectionModel().selectionChanged.connect(self.on_selection_changed)
|
||||
model.select()
|
||||
|
||||
for segment in segments:
|
||||
row_index = self.rowCount()
|
||||
self.insertRow(row_index)
|
||||
self.resizeColumnsToContents()
|
||||
|
||||
start_item = QTableWidgetItem(to_timestamp(segment.start))
|
||||
start_item.setFlags(
|
||||
start_item.flags()
|
||||
& ~Qt.ItemFlag.ItemIsEditable
|
||||
& ~Qt.ItemFlag.ItemIsSelectable
|
||||
)
|
||||
self.setItem(row_index, self.Column.START.value, start_item)
|
||||
def on_selection_changed(
    self, selected: QItemSelection, _deselected: QItemSelection
):
    """Emit ``segment_selected`` with the record of the first newly selected index.

    Nothing is emitted when ``selected`` contains no indexes.
    """
    indexes = selected.indexes()
    if not indexes:
        return
    self.segment_selected.emit(self.segment(indexes[0]))
|
||||
|
||||
end_item = QTableWidgetItem(to_timestamp(segment.end))
|
||||
end_item.setFlags(
|
||||
end_item.flags()
|
||||
& ~Qt.ItemFlag.ItemIsEditable
|
||||
& ~Qt.ItemFlag.ItemIsSelectable
|
||||
)
|
||||
self.setItem(row_index, self.Column.END.value, end_item)
|
||||
def segment(self, index: QModelIndex) -> QSqlRecord:
    """Return the model record for the row that ``index`` points at."""
    row = index.row()
    return self.model().record(row)
|
||||
|
||||
text_item = QTableWidgetItem(segment.text)
|
||||
self.setItem(row_index, self.Column.TEXT.value, text_item)
|
||||
|
||||
self.itemChanged.connect(self.on_item_changed)
|
||||
self.itemSelectionChanged.connect(self.on_item_selection_changed)
|
||||
|
||||
def on_item_changed(self, item: QTableWidgetItem):
    """Emit ``segment_text_changed`` as ``(row, text)`` when a TEXT cell changes."""
    if item.column() != self.Column.TEXT.value:
        return
    payload = (item.row(), item.text())
    self.segment_text_changed.emit(payload)
|
||||
|
||||
def set_segment_text(self, index: int, text: str):
    """Set the text of the TEXT cell in row ``index``."""
    text_cell = self.item(index, self.Column.TEXT.value)
    text_cell.setText(text)
|
||||
|
||||
def on_item_selection_changed(self):
    """Emit ``segment_index_selected`` with the first selected range's top row.

    ``-1`` is emitted when nothing is selected.
    """
    selected_ranges = self.selectedRanges()
    if len(selected_ranges) > 0:
        top_row = selected_ranges[0].topRow()
    else:
        top_row = -1
    self.segment_index_selected.emit(top_row)
|
||||
def segments(self) -> list[QSqlRecord]:
    """Return the record of every row currently in the model, in row order."""
    model = self.model()
    records = []
    for row in range(model.rowCount()):
        records.append(model.record(row))
    return records
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
import platform
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6.QtCore import Qt, pyqtSignal
|
||||
from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtMultimedia import QMediaPlayer
|
||||
from PyQt6.QtSql import QSqlRecord
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGridLayout,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from buzz.action import Action
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.locale import _
|
||||
from buzz.paths import file_path_as_title
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
Segment,
|
||||
)
|
||||
from buzz.transcriber.file_transcriber import write_output
|
||||
from buzz.transcriber.transcriber import OutputFormat, Segment
|
||||
from buzz.widgets.audio_player import AudioPlayer
|
||||
from buzz.widgets.icon import UndoIcon, RedoIcon
|
||||
from buzz.widgets.toolbar import ToolBar
|
||||
from buzz.widgets.transcription_viewer.export_transcription_button import (
|
||||
ExportTranscriptionButton,
|
||||
)
|
||||
|
@ -28,84 +27,31 @@ from buzz.widgets.transcription_viewer.transcription_segments_editor_widget impo
|
|||
)
|
||||
|
||||
|
||||
class ChangeSegmentTextCommand(QUndoCommand):
    """Undoable change of one segment's text.

    ``redo`` applies the new text and ``undo`` restores the text the segment
    had when the command was created. Both write the text to the table widget
    and to the ``segments`` list, then emit ``task_changed``.
    """

    def __init__(
        self,
        table_widget: TranscriptionSegmentsEditorWidget,
        segments: List[Segment],
        segment_index: int,
        segment_text: str,
        task_changed: pyqtSignal,
    ):
        """Store the edit and remember the segment's current text for undo.

        Args:
            table_widget: Widget whose cell text is updated on undo/redo.
            segments: Segment list whose entry is updated on undo/redo.
            segment_index: Index of the edited segment in ``segments``.
            segment_text: The new text applied by ``redo``.
            task_changed: Signal emitted after every text change.
        """
        super().__init__()

        self.table_widget = table_widget
        self.segments = segments
        self.segment_index = segment_index
        self.segment_text = segment_text
        self.task_changed = task_changed

        # Captured now, before any redo() runs, so undo() can restore it.
        self.previous_segment_text = self.segments[self.segment_index].text

    def undo(self) -> None:
        """Restore the text the segment had before this command."""
        self.set_segment_text(self.previous_segment_text)

    def redo(self) -> None:
        """Apply the new segment text."""
        self.set_segment_text(self.segment_text)

    def set_segment_text(self, text: str):
        """Write ``text`` to the table cell and the segment, then emit ``task_changed``."""
        # block signals before setting text so it doesn't re-trigger a new UndoCommand
        self.table_widget.blockSignals(True)
        try:
            self.table_widget.set_segment_text(self.segment_index, text)
        finally:
            # Always unblock: if setting the text raised, leaving the signals
            # blocked would silently stop the table from reporting later edits.
            self.table_widget.blockSignals(False)
        self.segments[self.segment_index].text = text
        self.task_changed.emit()
|
||||
|
||||
|
||||
class TranscriptionViewerWidget(QWidget):
|
||||
transcription_task: FileTranscriptionTask
|
||||
task_changed = pyqtSignal()
|
||||
transcription: Transcription
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transcription_task: FileTranscriptionTask,
|
||||
open_transcription_output=True,
|
||||
transcription: Transcription,
|
||||
parent: Optional["QWidget"] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||
) -> None:
|
||||
super().__init__(parent, flags)
|
||||
self.transcription_task = transcription_task
|
||||
self.open_transcription_output = open_transcription_output
|
||||
self.transcription = transcription
|
||||
|
||||
self.setMinimumWidth(800)
|
||||
self.setMinimumHeight(500)
|
||||
|
||||
self.setWindowTitle(file_path_as_title(transcription_task.file_path))
|
||||
|
||||
self.undo_stack = QUndoStack()
|
||||
|
||||
undo_action = self.undo_stack.createUndoAction(self, _("Undo"))
|
||||
undo_action.setShortcuts(QKeySequence.StandardKey.Undo)
|
||||
undo_action.setIcon(UndoIcon(parent=self))
|
||||
undo_action.setToolTip(Action.get_tooltip(undo_action))
|
||||
|
||||
redo_action = self.undo_stack.createRedoAction(self, _("Redo"))
|
||||
redo_action.setShortcuts(QKeySequence.StandardKey.Redo)
|
||||
redo_action.setIcon(RedoIcon(parent=self))
|
||||
redo_action.setToolTip(Action.get_tooltip(redo_action))
|
||||
|
||||
toolbar = ToolBar()
|
||||
toolbar.addActions([undo_action, redo_action])
|
||||
self.setWindowTitle(file_path_as_title(transcription.file))
|
||||
|
||||
self.table_widget = TranscriptionSegmentsEditorWidget(
|
||||
segments=transcription_task.segments, parent=self
|
||||
transcription_id=UUID(hex=transcription.id), parent=self
|
||||
)
|
||||
self.table_widget.segment_text_changed.connect(self.on_segment_text_changed)
|
||||
self.table_widget.segment_index_selected.connect(self.on_segment_index_selected)
|
||||
self.table_widget.segment_selected.connect(self.on_segment_selected)
|
||||
|
||||
self.audio_player: Optional[AudioPlayer] = None
|
||||
if platform.system() != "Linux":
|
||||
self.audio_player = AudioPlayer(file_path=transcription_task.file_path)
|
||||
self.audio_player = AudioPlayer(file_path=transcription.file)
|
||||
self.audio_player.position_ms_changed.connect(
|
||||
self.on_audio_player_position_ms_changed
|
||||
)
|
||||
|
@ -118,12 +64,10 @@ class TranscriptionViewerWidget(QWidget):
|
|||
buttons_layout = QHBoxLayout()
|
||||
buttons_layout.addStretch()
|
||||
|
||||
export_button = ExportTranscriptionButton(
|
||||
transcription_task=transcription_task, parent=self
|
||||
)
|
||||
export_button = ExportTranscriptionButton(parent=self)
|
||||
export_button.on_export_triggered.connect(self.on_export_triggered)
|
||||
|
||||
layout = QGridLayout(self)
|
||||
layout.setMenuBar(toolbar)
|
||||
layout.addWidget(self.table_widget, 0, 0, 1, 2)
|
||||
|
||||
if self.audio_player is not None:
|
||||
|
@ -133,33 +77,56 @@ class TranscriptionViewerWidget(QWidget):
|
|||
|
||||
self.setLayout(layout)
|
||||
|
||||
def on_segment_text_changed(self, event: tuple):
|
||||
segment_index, segment_text = event
|
||||
self.undo_stack.push(
|
||||
ChangeSegmentTextCommand(
|
||||
table_widget=self.table_widget,
|
||||
segments=self.transcription_task.segments,
|
||||
segment_index=segment_index,
|
||||
segment_text=segment_text,
|
||||
task_changed=self.task_changed,
|
||||
)
|
||||
def on_export_triggered(self, output_format: OutputFormat) -> None:
|
||||
default_path = self.transcription.get_output_file_path(
|
||||
output_format=output_format
|
||||
)
|
||||
|
||||
def on_segment_index_selected(self, index: int):
|
||||
selected_segment = self.transcription_task.segments[index]
|
||||
if self.audio_player is not None:
|
||||
self.audio_player.set_range((selected_segment.start, selected_segment.end))
|
||||
(output_file_path, nil) = QFileDialog.getSaveFileName(
|
||||
self,
|
||||
_("Save File"),
|
||||
default_path,
|
||||
_("Text files") + f" (*.{output_format.value})",
|
||||
)
|
||||
|
||||
if output_file_path == "":
|
||||
return
|
||||
|
||||
segments = [
|
||||
Segment(
|
||||
start=segment.value("start_time"),
|
||||
end=segment.value("end_time"),
|
||||
text=segment.value("text"),
|
||||
)
|
||||
for segment in self.table_widget.segments()
|
||||
]
|
||||
|
||||
write_output(
|
||||
path=output_file_path,
|
||||
segments=segments,
|
||||
output_format=output_format,
|
||||
)
|
||||
|
||||
def on_segment_selected(self, segment: QSqlRecord):
    """Set the player's range to the selected segment, but only while it is playing.

    Nothing happens when there is no audio player or when its media player
    is not in ``PlayingState``.
    """
    player = self.audio_player
    if player is None:
        return

    state = player.media_player.playbackState()
    if state != QMediaPlayer.PlaybackState.PlayingState:
        return

    start_time = segment.value("start_time")
    end_time = segment.value("end_time")
    player.set_range((start_time, end_time))
|
||||
|
||||
def on_audio_player_position_ms_changed(self, position_ms: int) -> None:
|
||||
current_segment_index: Optional[int] = next(
|
||||
segments = self.table_widget.segments()
|
||||
current_segment = next(
|
||||
(
|
||||
i
|
||||
for i, segment in enumerate(self.transcription_task.segments)
|
||||
if segment.start <= position_ms < segment.end
|
||||
segment
|
||||
for segment in segments
|
||||
if segment.value("start_time")
|
||||
<= position_ms
|
||||
< segment.value("end_time")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if current_segment_index is not None:
|
||||
self.current_segment_label.setText(
|
||||
self.transcription_task.segments[current_segment_index].text
|
||||
)
|
||||
if current_segment is not None:
|
||||
self.current_segment_label.setText(current_segment.value("text"))
|
||||
|
|
1123
poetry.lock
generated
1123
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -14,7 +14,6 @@ packages = [
|
|||
[tool.poetry.dependencies]
|
||||
python = ">=3.9.13,<3.11"
|
||||
sounddevice = "^0.4.5"
|
||||
appdirs = "^1.4.4"
|
||||
humanize = "^4.4.0"
|
||||
PyQt6 = "^6.4.0"
|
||||
openai = "^1.6.1"
|
||||
|
|
40
tests/conftest.py
Normal file
40
tests/conftest.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from PyQt6.QtSql import QSqlDatabase
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from buzz.db.dao.transcription_dao import TranscriptionDAO
|
||||
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
|
||||
from buzz.db.db import setup_test_db
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
|
||||
|
||||
@pytest.fixture()
def db() -> QSqlDatabase:
    """Provide the database from ``setup_test_db`` and clean it up afterwards."""
    database = setup_test_db()
    yield database
    # Teardown: close the connection first, then delete the database file.
    database.close()
    os.remove(database.databaseName())
|
||||
|
||||
|
||||
@pytest.fixture()
def transcription_dao(db, request: SubRequest) -> TranscriptionDAO:
    """Provide a ``TranscriptionDAO`` on the test database.

    When the fixture is parametrized, every transcription in
    ``request.param`` is inserted before the DAO is returned.
    """
    dao = TranscriptionDAO(db)
    # ``request`` only has a ``param`` attribute when the fixture is parametrized.
    if not hasattr(request, "param"):
        return dao
    for transcription in request.param:
        dao.insert(transcription)
    return dao
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def transcription_service(
|
||||
transcription_dao, transcription_segment_dao
|
||||
) -> TranscriptionService:
|
||||
return TranscriptionService(transcription_dao, transcription_segment_dao)
|
||||
|
||||
|
||||
@pytest.fixture()
def transcription_segment_dao(db) -> TranscriptionSegmentDAO:
    """Provide a ``TranscriptionSegmentDAO`` on the test database."""
    segment_dao = TranscriptionSegmentDAO(db)
    return segment_dao
|
|
@ -54,13 +54,15 @@ class TestWhisperFileTranscriber:
|
|||
expected_file_path: str,
|
||||
):
|
||||
file_path = get_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path=file_path,
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
file_path=file_path,
|
||||
language=None,
|
||||
task=Task.TRANSLATE,
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
output_format=output_format,
|
||||
output_directory="",
|
||||
export_file_name_template="{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
|
||||
)
|
||||
assert file_path == expected_file_path
|
||||
|
@ -87,15 +89,15 @@ class TestWhisperFileTranscriber:
|
|||
"{{ input_file_name }} (Translated on {{ date_time }})"
|
||||
)
|
||||
srt = get_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path=file_path,
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=[],
|
||||
),
|
||||
model_path="",
|
||||
file_path=file_path,
|
||||
language=None,
|
||||
task=Task.TRANSLATE,
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
output_format=OutputFormat.TXT,
|
||||
output_directory="",
|
||||
export_file_name_template=export_file_name_template,
|
||||
)
|
||||
|
||||
|
@ -103,15 +105,15 @@ class TestWhisperFileTranscriber:
|
|||
assert srt.endswith(".txt")
|
||||
|
||||
srt = get_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path=file_path,
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=[],
|
||||
),
|
||||
model_path="",
|
||||
file_path=file_path,
|
||||
language=None,
|
||||
task=Task.TRANSLATE,
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
output_format=OutputFormat.SRT,
|
||||
output_directory="",
|
||||
export_file_name_template=export_file_name_template,
|
||||
)
|
||||
assert srt.startswith(expected_starts_with)
|
||||
|
|
|
@ -6,50 +6,26 @@ import pytest
|
|||
from PyQt6.QtCore import QSize, Qt
|
||||
from PyQt6.QtGui import QKeyEvent, QAction
|
||||
from PyQt6.QtWidgets import (
|
||||
QTableWidget,
|
||||
QMessageBox,
|
||||
QPushButton,
|
||||
QToolBar,
|
||||
QMenuBar,
|
||||
QTableView,
|
||||
)
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionOptions,
|
||||
)
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
from buzz.widgets.main_window import MainWindow
|
||||
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
|
||||
TranscriptionViewerWidget,
|
||||
)
|
||||
|
||||
mock_tasks = [
|
||||
FileTranscriptionTask(
|
||||
file_path="",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.COMPLETED,
|
||||
),
|
||||
FileTranscriptionTask(
|
||||
file_path="",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.CANCELED,
|
||||
),
|
||||
FileTranscriptionTask(
|
||||
file_path="",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.FAILED,
|
||||
error="Error",
|
||||
),
|
||||
mock_transcriptions: List[Transcription] = [
|
||||
Transcription(status="completed"),
|
||||
Transcription(status="canceled"),
|
||||
Transcription(status="failed", error_message="Error"),
|
||||
]
|
||||
|
||||
|
||||
|
@ -58,37 +34,41 @@ def get_test_asset(filename: str):
|
|||
|
||||
|
||||
class TestMainWindow:
|
||||
def test_should_set_window_title_and_icon(self, qtbot):
|
||||
window = MainWindow()
|
||||
def test_should_set_window_title_and_icon(self, qtbot, transcription_service):
|
||||
window = MainWindow(transcription_service)
|
||||
qtbot.add_widget(window)
|
||||
assert window.windowTitle() == "Buzz"
|
||||
assert window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
|
||||
window.close()
|
||||
|
||||
def test_should_run_file_transcription_task(self, qtbot: QtBot, tasks_cache):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
def test_should_run_file_transcription_task(
|
||||
self, qtbot: QtBot, transcription_service
|
||||
):
|
||||
window = MainWindow(transcription_service)
|
||||
|
||||
self.import_file_and_start_transcription(window)
|
||||
self._import_file_and_start_transcription(window)
|
||||
|
||||
open_transcript_action = self._get_toolbar_action(window, "Open Transcript")
|
||||
assert open_transcript_action.isEnabled() is False
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget = self._get_tasks_table(window)
|
||||
qtbot.wait_until(
|
||||
self.get_assert_task_status_callback(table_widget, 0, "Completed"),
|
||||
self._get_assert_task_status_callback(table_widget, 0, "completed"),
|
||||
timeout=2 * 60 * 1000,
|
||||
)
|
||||
|
||||
table_widget.setCurrentIndex(
|
||||
table_widget.indexFromItem(table_widget.item(0, 1))
|
||||
)
|
||||
table_widget.setCurrentIndex(table_widget.model().index(0, 0))
|
||||
assert open_transcript_action.isEnabled()
|
||||
window.close()
|
||||
|
||||
@staticmethod
def _get_tasks_table(window: MainWindow) -> QTableView:
    """Return the ``QTableView`` child that ``findChild`` locates under ``window``."""
    tasks_table = window.findChild(QTableView)
    return tasks_table
|
||||
|
||||
def test_should_run_url_import_file_transcription_task(
|
||||
self, qtbot: QtBot, tasks_cache
|
||||
self, qtbot: QtBot, db, transcription_service
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
window = MainWindow(transcription_service)
|
||||
menu: QMenuBar = window.menuBar()
|
||||
file_action = menu.actions()[0]
|
||||
import_url_action: QAction = file_action.menu().actions()[1]
|
||||
|
@ -105,24 +85,26 @@ class TestMainWindow:
|
|||
run_button: QPushButton = file_transcriber_widget.findChild(QPushButton)
|
||||
run_button.click()
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget = self._get_tasks_table(window)
|
||||
qtbot.wait_until(
|
||||
self.get_assert_task_status_callback(table_widget, 0, "Completed"),
|
||||
self._get_assert_task_status_callback(table_widget, 0, "completed"),
|
||||
timeout=2 * 60 * 1000,
|
||||
)
|
||||
|
||||
window.close()
|
||||
|
||||
def test_should_run_and_cancel_transcription_task(self, qtbot, tasks_cache):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
def test_should_run_and_cancel_transcription_task(
|
||||
self, qtbot, db, transcription_service
|
||||
):
|
||||
window = MainWindow(transcription_service)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
self.import_file_and_start_transcription(window, long_audio=True)
|
||||
self._import_file_and_start_transcription(window, long_audio=True)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget = self._get_tasks_table(window)
|
||||
|
||||
qtbot.wait_until(
|
||||
self.get_assert_task_status_callback(table_widget, 0, "In Progress"),
|
||||
self._get_assert_task_status_callback(table_widget, 0, "in_progress"),
|
||||
timeout=2 * 60 * 1000,
|
||||
)
|
||||
|
||||
|
@ -131,7 +113,7 @@ class TestMainWindow:
|
|||
window.toolbar.stop_transcription_action.trigger()
|
||||
|
||||
qtbot.wait_until(
|
||||
self.get_assert_task_status_callback(table_widget, 0, "Canceled"),
|
||||
self._get_assert_task_status_callback(table_widget, 0, "canceled"),
|
||||
timeout=60 * 1000,
|
||||
)
|
||||
|
||||
|
@ -141,60 +123,72 @@ class TestMainWindow:
|
|||
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_load_tasks_from_cache(self, qtbot, tasks_cache):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
|
||||
def test_should_load_tasks_from_cache(
|
||||
self, qtbot, transcription_dao, transcription_segment_dao
|
||||
):
|
||||
window = MainWindow(
|
||||
TranscriptionService(transcription_dao, transcription_segment_dao)
|
||||
)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
assert table_widget.rowCount() == 3
|
||||
table_widget = self._get_tasks_table(window)
|
||||
assert table_widget.model().rowCount() == 3
|
||||
|
||||
assert table_widget.item(0, 4).text() == "Completed"
|
||||
assert self._get_status(table_widget, 0) == "completed"
|
||||
table_widget.selectRow(0)
|
||||
assert window.toolbar.open_transcript_action.isEnabled()
|
||||
|
||||
assert table_widget.item(1, 4).text() == "Canceled"
|
||||
assert self._get_status(table_widget, 1) == "canceled"
|
||||
table_widget.selectRow(1)
|
||||
assert window.toolbar.open_transcript_action.isEnabled() is False
|
||||
|
||||
assert table_widget.item(2, 4).text() == "Failed (Error)"
|
||||
assert self._get_status(table_widget, 2) == "failed"
|
||||
table_widget.selectRow(2)
|
||||
assert window.toolbar.open_transcript_action.isEnabled() is False
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_clear_history_with_rows_selected(self, qtbot, tasks_cache):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
|
||||
def test_should_clear_history_with_rows_selected(
|
||||
self, qtbot, transcription_dao, transcription_segment_dao
|
||||
):
|
||||
window = MainWindow(
|
||||
TranscriptionService(transcription_dao, transcription_segment_dao)
|
||||
)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget = self._get_tasks_table(window)
|
||||
table_widget.selectAll()
|
||||
|
||||
with patch("PyQt6.QtWidgets.QMessageBox.question") as question_message_box_mock:
|
||||
question_message_box_mock.return_value = QMessageBox.StandardButton.Yes
|
||||
window.toolbar.clear_history_action.trigger()
|
||||
|
||||
assert table_widget.rowCount() == 0
|
||||
assert table_widget.model().rowCount() == 0
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
|
||||
def test_should_have_clear_history_action_disabled_with_no_rows_selected(
|
||||
self, qtbot, tasks_cache
|
||||
self, qtbot, transcription_dao, transcription_segment_dao
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
window = MainWindow(
|
||||
TranscriptionService(transcription_dao, transcription_segment_dao)
|
||||
)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
assert window.toolbar.clear_history_action.isEnabled() is False
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
|
||||
def test_should_open_transcription_viewer_when_menu_action_is_clicked(
|
||||
self, qtbot, tasks_cache
|
||||
self, qtbot, transcription_dao, transcription_segment_dao
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
window = MainWindow(
|
||||
TranscriptionService(transcription_dao, transcription_segment_dao)
|
||||
)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
|
||||
table_widget = self._get_tasks_table(window)
|
||||
table_widget.selectRow(0)
|
||||
|
||||
window.toolbar.open_transcript_action.trigger()
|
||||
|
@ -204,14 +198,16 @@ class TestMainWindow:
|
|||
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
|
||||
def test_should_open_transcription_viewer_when_return_clicked(
|
||||
self, qtbot, tasks_cache
|
||||
self, qtbot, transcription_dao, transcription_segment_dao
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
window = MainWindow(
|
||||
TranscriptionService(transcription_dao, transcription_segment_dao)
|
||||
)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget = self._get_tasks_table(window)
|
||||
table_widget.selectRow(0)
|
||||
table_widget.keyPressEvent(
|
||||
QKeyEvent(
|
||||
|
@ -227,18 +223,20 @@ class TestMainWindow:
|
|||
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_have_open_transcript_action_disabled_with_no_rows_selected(
    self, qtbot, transcription_dao, transcription_segment_dao
):
    """The "open transcript" toolbar action is disabled while nothing is selected.

    The window is built on a ``TranscriptionService`` wrapping the two DAO
    fixtures and the action state is read without selecting any row.
    """
    window = MainWindow(
        TranscriptionService(transcription_dao, transcription_segment_dao)
    )
    qtbot.add_widget(window)

    assert window.toolbar.open_transcript_action.isEnabled() is False
    window.close()
|
||||
|
||||
@staticmethod
|
||||
def import_file_and_start_transcription(
|
||||
def _import_file_and_start_transcription(
|
||||
window: MainWindow, long_audio: bool = False
|
||||
):
|
||||
with patch(
|
||||
|
@ -264,34 +262,24 @@ class TestMainWindow:
|
|||
run_button.click()
|
||||
|
||||
@staticmethod
def _get_assert_task_status_callback(
    table_widget: QTableView,
    row_index: int,
    expected_status: str,
    long_audio: bool = False,
):
    """Build a zero-argument callback that asserts a row's status.

    Args:
        table_widget: View whose model holds the transcription rows.
        row_index: Row to inspect.
        expected_status: Substring that must occur in the row's status value.
        long_audio: Accepted so existing call sites keep working; the
            callback does not consult it.

    Returns:
        A callable that raises ``AssertionError`` until the model has at
        least one row and the status of ``row_index`` contains
        ``expected_status``.
    """

    def assert_task_status():
        assert table_widget.model().rowCount() > 0
        assert expected_status in TestMainWindow._get_status(
            table_widget, row_index
        )

    return assert_task_status
|
||||
|
||||
@staticmethod
def _get_status(table_widget: QTableView, row_index: int):
    """Return the data held in column 9 of ``row_index`` in the view's model."""
    model = table_widget.model()
    status_cell = model.index(row_index, 9)
    return status_cell.data()
|
||||
|
||||
@staticmethod
def _get_toolbar_action(window: MainWindow, text: str):
    """Return the first action of the window's ``QToolBar`` whose text is ``text``.

    Raises:
        IndexError: If no action on the toolbar has that text.
    """
    toolbar: QToolBar = window.findChild(QToolBar)

    matching_actions = []
    for candidate in toolbar.actions():
        if candidate.text() == text:
            matching_actions.append(candidate)

    return matching_actions[0]
|
||||
|
||||
|
||||
@pytest.fixture()
def tasks_cache(tmp_path, request: SubRequest):
    """Yield a ``TasksCache`` stored under ``tmp_path`` and clear it afterwards.

    When the fixture is parametrized (``request.param`` is present), the
    given tasks are saved into the cache before it is handed to the test.
    """
    tasks_store = TasksCache(cache_dir=str(tmp_path))

    parametrized = hasattr(request, "param")
    if parametrized:
        initial_tasks: List[FileTranscriptionTask] = request.param
        tasks_store.save(initial_tasks)

    yield tasks_store

    # Teardown: runs once the test using the fixture has finished.
    tasks_store.clear()
|
||||
|
|
|
@ -1,115 +1,7 @@
|
|||
import datetime
|
||||
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionOptions,
|
||||
)
|
||||
from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
|
||||
|
||||
|
||||
class TestTranscriptionTasksTableWidget:
    """Tests for ``TranscriptionTasksTableWidget``.

    NOTE(review): apart from ``test_can_create``, the tests here drive the
    widget through ``upsert_task``/``rowCount``/``item``/``isColumnHidden``.
    The surrounding diff hunk header (``-1,115 +1,7``) suggests the commit
    kept only ``test_can_create`` -- confirm which of these still apply.
    """

    def test_can_create(self, qtbot, reset_settings):
        """The widget can be constructed and registered with ``qtbot``."""
        widget = TranscriptionTasksTableWidget()
        qtbot.add_widget(widget)

    def test_upsert_task(self, qtbot: QtBot):
        """Upserting the same task keeps one row and refreshes its status text."""
        widget = TranscriptionTasksTableWidget()
        qtbot.add_widget(widget)

        task = FileTranscriptionTask(
            id=0,
            file_path="testdata/whisper-french.mp3",
            transcription_options=TranscriptionOptions(),
            file_transcription_options=FileTranscriptionOptions(
                file_paths=["testdata/whisper-french.mp3"]
            ),
            model_path="",
            status=FileTranscriptionTask.Status.QUEUED,
        )
        task.queued_at = datetime.datetime(2023, 4, 12, 0, 0, 0)
        task.started_at = datetime.datetime(2023, 4, 12, 0, 0, 5)

        widget.upsert_task(task)

        assert widget.rowCount() == 1
        self.assert_row_text(
            widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Queued"
        )

        task.status = FileTranscriptionTask.Status.IN_PROGRESS
        task.fraction_completed = 0.3524
        widget.upsert_task(task)

        assert widget.rowCount() == 1
        self.assert_row_text(
            widget,
            0,
            "whisper-french.mp3",
            "Whisper (Tiny)",
            "Transcribe",
            "In Progress (35%)",
        )

        task.status = FileTranscriptionTask.Status.COMPLETED
        # started_at is 00:00:05 and completed_at 00:00:10, matching the
        # "(5s)" expected in the status text below.
        task.completed_at = datetime.datetime(2023, 4, 12, 0, 0, 10)
        widget.upsert_task(task)

        assert widget.rowCount() == 1
        self.assert_row_text(
            widget,
            0,
            "whisper-french.mp3",
            "Whisper (Tiny)",
            "Transcribe",
            "Completed (5s)",
        )

    def test_upsert_task_no_timings(self, qtbot: QtBot):
        """A completed task without timestamps shows a bare "Completed" status."""
        widget = TranscriptionTasksTableWidget()
        qtbot.add_widget(widget)

        task = FileTranscriptionTask(
            id=0,
            file_path="testdata/whisper-french.mp3",
            transcription_options=TranscriptionOptions(),
            file_transcription_options=FileTranscriptionOptions(
                file_paths=["testdata/whisper-french.mp3"]
            ),
            model_path="",
            status=FileTranscriptionTask.Status.COMPLETED,
        )
        widget.upsert_task(task)

        assert widget.rowCount() == 1
        self.assert_row_text(
            widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Completed"
        )

    def test_toggle_column_visibility(self, qtbot, reset_settings):
        """Only the file-name and status columns are visible on a fresh widget."""
        widget = TranscriptionTasksTableWidget()
        qtbot.add_widget(widget)

        assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.TASK_ID.value)
        assert not widget.isColumnHidden(
            TranscriptionTasksTableWidget.Column.FILE_NAME.value
        )
        assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.MODEL.value)
        assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.TASK.value)
        assert not widget.isColumnHidden(
            TranscriptionTasksTableWidget.Column.STATUS.value
        )

        # TODO: open context menu and toggle column visibility

    def assert_row_text(
        self,
        widget: TranscriptionTasksTableWidget,
        row: int,
        filename: str,
        model: str,
        task: str,
        status: str,
    ):
        """Assert the text of columns 1-4 (file name, model, task, status) of ``row``."""
        assert widget.item(row, 1).text() == filename
        assert widget.item(row, 2).text() == model
        assert widget.item(row, 3).text() == task
        assert widget.item(row, 4).text() == status
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
import pathlib
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from PyQt6.QtWidgets import QPushButton, QToolBar
|
||||
from PyQt6.QtWidgets import QPushButton
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
FileTranscriptionOptions,
|
||||
TranscriptionOptions,
|
||||
Segment,
|
||||
)
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.db.entity.transcription_segment import TranscriptionSegment
|
||||
from buzz.model_loader import ModelType, WhisperModelSize
|
||||
from buzz.transcriber.transcriber import Task
|
||||
from buzz.widgets.transcription_viewer.transcription_segments_editor_widget import (
|
||||
TranscriptionSegmentsEditorWidget,
|
||||
)
|
||||
|
@ -21,22 +20,29 @@ from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
|
|||
|
||||
class TestTranscriptionViewerWidget:
|
||||
@pytest.fixture()
def transcription(
    self, transcription_dao, transcription_segment_dao
) -> Transcription:
    """Insert one completed transcription with two segments and return it.

    The row is written through ``transcription_dao``, its segments
    ("Bien" 40-299 and "venue dans" 299-329) through
    ``transcription_segment_dao``, and the value handed to the test is
    whatever ``transcription_dao.find_by_id`` returns for the new id.
    """
    # Generate the id once as a string; the name avoids shadowing the
    # builtin ``id``.
    transcription_id = str(uuid.uuid4())

    transcription_dao.insert(
        Transcription(
            id=transcription_id,
            status="completed",
            file="testdata/whisper-french.mp3",
            task=Task.TRANSCRIBE.value,
            model_type=ModelType.WHISPER.value,
            whisper_model_size=WhisperModelSize.SMALL.value,
        )
    )
    transcription_segment_dao.insert(
        TranscriptionSegment(40, 299, "Bien", transcription_id)
    )
    transcription_segment_dao.insert(
        TranscriptionSegment(299, 329, "venue dans", transcription_id)
    )

    return transcription_dao.find_by_id(transcription_id)
|
||||
|
||||
def test_should_display_segments(self, qtbot: QtBot, transcription):
|
||||
widget = TranscriptionViewerWidget(transcription)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
assert widget.windowTitle() == "whisper-french.mp3"
|
||||
|
@ -44,37 +50,23 @@ class TestTranscriptionViewerWidget:
|
|||
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
|
||||
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
|
||||
|
||||
assert editor.item(0, 0).text() == "00:00:00.040"
|
||||
assert editor.item(0, 1).text() == "00:00:00.299"
|
||||
assert editor.item(0, 2).text() == "Bien"
|
||||
assert editor.model().index(0, 1).data() == 299
|
||||
assert editor.model().index(0, 2).data() == 40
|
||||
assert editor.model().index(0, 3).data() == "Bien"
|
||||
|
||||
def test_should_update_segment_text(self, qtbot, transcription):
    """Write new text into the first segment through the editor's model.

    The viewer is built from the ``transcription`` fixture; the segments
    editor is located with ``findChild`` and column 3 of row 0 is set to
    "Biens" via ``setData``.
    """
    widget = TranscriptionViewerWidget(transcription)
    qtbot.add_widget(widget)

    editor = widget.findChild(TranscriptionSegmentsEditorWidget)
    assert isinstance(editor, TranscriptionSegmentsEditorWidget)

    segments_model = editor.model()
    segments_model.setData(segments_model.index(0, 3), "Biens")
|
||||
|
||||
def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot, task):
|
||||
widget = TranscriptionViewerWidget(
|
||||
transcription_task=task, open_transcription_output=False
|
||||
)
|
||||
def test_should_export_segments(
|
||||
self, tmp_path: pathlib.Path, qtbot: QtBot, transcription
|
||||
):
|
||||
widget = TranscriptionViewerWidget(transcription)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
export_button = widget.findChild(QPushButton)
|
||||
|
@ -87,5 +79,5 @@ class TestTranscriptionViewerWidget:
|
|||
save_file_name_mock.return_value = (str(output_file_path), "")
|
||||
export_button.menu().actions()[0].trigger()
|
||||
|
||||
output_file = open(output_file_path, "r", encoding="utf-8")
|
||||
assert "Bien\nvenue dans" in output_file.read()
|
||||
with open(output_file_path, encoding="utf-8") as output_file:
|
||||
assert "Bien\nvenue dans" in output_file.read()
|
||||
|
|
Loading…
Reference in a new issue