feat: save transcriptions to sqlite (#682)

This commit is contained in:
Chidi Williams 2024-03-14 01:51:06 +00:00 committed by GitHub
parent dfac983f13
commit ae5af308b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 2039 additions and 1343 deletions

View file

@ -28,9 +28,9 @@ clean:
rm -f buzz/whisper_cpp.py
rm -rf dist/* || true
COVERAGE_THRESHOLD := 76
COVERAGE_THRESHOLD := 77
ifeq ($(UNAME_S),Linux)
COVERAGE_THRESHOLD := 71
COVERAGE_THRESHOLD := 72
endif
test: buzz/whisper_cpp.py translation_mo

View file

@ -6,7 +6,7 @@ import platform
import sys
from typing import TextIO
from appdirs import user_log_dir
from platformdirs import user_log_dir
# Check for segfaults if not running in frozen mode
if getattr(sys, "frozen", False) is False:
@ -57,7 +57,14 @@ def main():
from buzz.cli import parse_command_line
from buzz.widgets.application import Application
from buzz.db.dao.transcription_dao import TranscriptionDAO
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
from buzz.db.service.transcription_service import TranscriptionService
from buzz.db.db import setup_app_db
app = Application()
db = setup_app_db()
app = Application(
TranscriptionService(TranscriptionDAO(db), TranscriptionSegmentDAO(db))
)
parse_command_line(app)
sys.exit(app.exec())

0
buzz/db/__init__.py Normal file
View file

0
buzz/db/dao/__init__.py Normal file
View file

53
buzz/db/dao/dao.py Normal file
View file

@ -0,0 +1,53 @@
# Adapted from https://github.com/zhiyiYo/Groove
from abc import ABC
from typing import TypeVar, Generic, Any, Type
from PyQt6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord
from buzz.db.entity.entity import Entity
T = TypeVar("T", bound=Entity)
class DAO(ABC, Generic[T]):
entity: Type[T]
def __init__(self, table: str, db: QSqlDatabase):
self.db = db
self.table = table
def insert(self, record: T):
query = self._create_query()
keys = record.__dict__.keys()
query.prepare(
f"""
INSERT INTO {self.table} ({", ".join(keys)})
VALUES ({", ".join([f":{key}" for key in keys])})
"""
)
for key, value in record.__dict__.items():
query.bindValue(f":{key}", value)
if not query.exec():
raise Exception(query.lastError().text())
def find_by_id(self, id: Any) -> T | None:
query = self._create_query()
query.prepare(f"SELECT * FROM {self.table} WHERE id = :id")
query.bindValue(":id", id)
return self._execute(query)
def to_entity(self, record: QSqlRecord) -> T:
entity = self.entity()
for i in range(record.count()):
setattr(entity, record.fieldName(i), record.value(i))
return entity
def _execute(self, query: QSqlQuery) -> T | None:
if not query.exec():
raise Exception(query.lastError().text())
if not query.first():
return None
return self.to_entity(query.record())
def _create_query(self):
return QSqlQuery(self.db)

View file

@ -0,0 +1,159 @@
from datetime import datetime
from uuid import UUID
from PyQt6.QtSql import QSqlDatabase
from buzz.db.dao.dao import DAO
from buzz.db.entity.transcription import Transcription
from buzz.transcriber.transcriber import FileTranscriptionTask
class TranscriptionDAO(DAO[Transcription]):
entity = Transcription
def __init__(self, db: QSqlDatabase):
super().__init__("transcription", db)
def create_transcription(self, task: FileTranscriptionTask):
query = self._create_query()
query.prepare(
"""
INSERT INTO transcription (
id,
export_formats,
file,
output_folder,
language,
model_type,
source,
status,
task,
time_queued,
url,
whisper_model_size
) VALUES (
:id,
:export_formats,
:file,
:output_folder,
:language,
:model_type,
:source,
:status,
:task,
:time_queued,
:url,
:whisper_model_size
)
"""
)
query.bindValue(":id", str(task.uid))
query.bindValue(
":export_formats",
", ".join(
[
output_format.value
for output_format in task.file_transcription_options.output_formats
]
),
)
query.bindValue(":file", task.file_path)
query.bindValue(":output_folder", task.output_directory)
query.bindValue(":language", task.transcription_options.language)
query.bindValue(
":model_type", task.transcription_options.model.model_type.value
)
query.bindValue(":source", task.source.value)
query.bindValue(":status", FileTranscriptionTask.Status.QUEUED.value)
query.bindValue(":task", task.transcription_options.task.value)
query.bindValue(":time_queued", datetime.now().isoformat())
query.bindValue(":url", task.url)
query.bindValue(
":whisper_model_size",
task.transcription_options.model.whisper_model_size.value
if task.transcription_options.model.whisper_model_size
else None,
)
if not query.exec():
raise Exception(query.lastError().text())
def update_transcription_as_started(self, id: UUID):
query = self._create_query()
query.prepare(
"""
UPDATE transcription
SET status = :status, time_started = :time_started
WHERE id = :id
"""
)
query.bindValue(":id", str(id))
query.bindValue(":status", FileTranscriptionTask.Status.IN_PROGRESS.value)
query.bindValue(":time_started", datetime.now().isoformat())
if not query.exec():
raise Exception(query.lastError().text())
def update_transcription_as_failed(self, id: UUID, error: str):
query = self._create_query()
query.prepare(
"""
UPDATE transcription
SET status = :status, time_ended = :time_ended, error_message = :error_message
WHERE id = :id
"""
)
query.bindValue(":id", str(id))
query.bindValue(":status", FileTranscriptionTask.Status.FAILED.value)
query.bindValue(":time_ended", datetime.now().isoformat())
query.bindValue(":error_message", error)
if not query.exec():
raise Exception(query.lastError().text())
def update_transcription_as_canceled(self, id: UUID):
query = self._create_query()
query.prepare(
"""
UPDATE transcription
SET status = :status, time_ended = :time_ended
WHERE id = :id
"""
)
query.bindValue(":id", str(id))
query.bindValue(":status", FileTranscriptionTask.Status.CANCELED.value)
query.bindValue(":time_ended", datetime.now().isoformat())
if not query.exec():
raise Exception(query.lastError().text())
def update_transcription_progress(self, id: UUID, progress: float):
query = self._create_query()
query.prepare(
"""
UPDATE transcription
SET status = :status, progress = :progress
WHERE id = :id
"""
)
query.bindValue(":id", str(id))
query.bindValue(":status", FileTranscriptionTask.Status.IN_PROGRESS.value)
query.bindValue(":progress", progress)
if not query.exec():
raise Exception(query.lastError().text())
def update_transcription_as_completed(self, id: UUID):
query = self._create_query()
query.prepare(
"""
UPDATE transcription
SET status = :status, time_ended = :time_ended
WHERE id = :id
"""
)
query.bindValue(":id", str(id))
query.bindValue(":status", FileTranscriptionTask.Status.COMPLETED.value)
query.bindValue(":time_ended", datetime.now().isoformat())
if not query.exec():
raise Exception(query.lastError().text())

View file

@ -0,0 +1,11 @@
from PyQt6.QtSql import QSqlDatabase
from buzz.db.dao.dao import DAO
from buzz.db.entity.transcription_segment import TranscriptionSegment
class TranscriptionSegmentDAO(DAO[TranscriptionSegment]):
entity = TranscriptionSegment
def __init__(self, db: QSqlDatabase):
super().__init__("transcription_segment", db)

39
buzz/db/db.py Normal file
View file

@ -0,0 +1,39 @@
import logging
import os
import sqlite3
import tempfile
from PyQt6.QtSql import QSqlDatabase
from platformdirs import user_data_dir
from buzz.db.helpers import (
run_sqlite_migrations,
copy_transcriptions_from_json_to_sqlite,
mark_in_progress_and_queued_transcriptions_as_canceled,
)
APP_DB_PATH = os.path.join(user_data_dir("Buzz"), "Buzz.sqlite")
def setup_app_db() -> QSqlDatabase:
return _setup_db(APP_DB_PATH)
def setup_test_db() -> QSqlDatabase:
return _setup_db(tempfile.mktemp())
def _setup_db(path: str) -> QSqlDatabase:
# Run migrations
db = sqlite3.connect(path)
run_sqlite_migrations(db)
copy_transcriptions_from_json_to_sqlite(db)
mark_in_progress_and_queued_transcriptions_as_canceled(db)
db.close()
db = QSqlDatabase.addDatabase("QSQLITE")
db.setDatabaseName(path)
if not db.open():
raise RuntimeError(f"Failed to open database connection: {db.databaseName()}")
logging.debug("Database connection opened: %s", db.databaseName())
return db

View file

12
buzz/db/entity/entity.py Normal file
View file

@ -0,0 +1,12 @@
from abc import ABC
from PyQt6.QtSql import QSqlRecord
class Entity(ABC):
@classmethod
def from_record(cls, record: QSqlRecord):
entity = cls()
for i in range(record.count()):
setattr(entity, record.fieldName(i), record.value(i))
return entity

View file

@ -0,0 +1,54 @@
import datetime
import os
import uuid
from dataclasses import dataclass, field
from buzz.db.entity.entity import Entity
from buzz.model_loader import ModelType
from buzz.settings.settings import Settings
from buzz.transcriber.transcriber import OutputFormat, Task, FileTranscriptionTask
@dataclass
class Transcription(Entity):
status: str = FileTranscriptionTask.Status.QUEUED.value
task: str = Task.TRANSCRIBE.value
model_type: str = ModelType.WHISPER.value
whisper_model_size: str | None = None
language: str | None = None
id: str = field(default_factory=lambda: str(uuid.uuid4()))
error_message: str | None = None
file: str | None = None
time_queued: str = datetime.datetime.now().isoformat()
@property
def id_as_uuid(self):
return uuid.UUID(hex=self.id)
@property
def status_as_status(self):
return FileTranscriptionTask.Status(self.status)
def get_output_file_path(
self,
output_format: OutputFormat,
output_directory: str | None = None,
):
input_file_name = os.path.splitext(os.path.basename(self.file))[0]
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
export_file_name_template = Settings().get_default_export_file_template()
output_file_name = (
export_file_name_template.replace("{{ input_file_name }}", input_file_name)
.replace("{{ task }}", self.task)
.replace("{{ language }}", self.language or "")
.replace("{{ model_type }}", self.model_type)
.replace("{{ model_size }}", self.whisper_model_size or "")
.replace("{{ date_time }}", date_time_now)
+ f".{output_format.value}"
)
output_directory = output_directory or os.path.dirname(self.file)
return os.path.join(output_directory, output_file_name)

View file

@ -0,0 +1,11 @@
from dataclasses import dataclass
from buzz.db.entity.entity import Entity
@dataclass
class TranscriptionSegment(Entity):
start_time: int
end_time: int
text: str
transcription_id: str

84
buzz/db/helpers.py Normal file
View file

@ -0,0 +1,84 @@
import os
from datetime import datetime
from sqlite3 import Connection
from buzz.cache import TasksCache
from buzz.db.migrator import dumb_migrate_db
def copy_transcriptions_from_json_to_sqlite(conn: Connection):
cache = TasksCache()
if os.path.exists(cache.tasks_list_file_path):
tasks = cache.load()
cursor = conn.cursor()
for task in tasks:
cursor.execute(
"""
INSERT INTO transcription (id, error_message, export_formats, file, output_folder, progress, language, model_type, source, status, task, time_ended, time_queued, time_started, url, whisper_model_size)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
RETURNING id;
""",
(
str(task.uid),
task.error,
", ".join(
[
format.value
for format in task.file_transcription_options.output_formats
]
),
task.file_path,
task.output_directory,
task.fraction_completed,
task.transcription_options.language,
task.transcription_options.model.model_type.value,
task.source.value,
task.status.value,
task.transcription_options.task.value,
task.completed_at,
task.queued_at,
task.started_at,
task.url,
task.transcription_options.model.whisper_model_size.value
if task.transcription_options.model.whisper_model_size
else None,
),
)
transcription_id = cursor.fetchone()[0]
for segment in task.segments:
cursor.execute(
"""
INSERT INTO transcription_segment (end_time, start_time, text, transcription_id)
VALUES (?, ?, ?, ?);
""",
(
segment.end,
segment.start,
segment.text,
transcription_id,
),
)
# os.remove(cache.tasks_list_file_path)
conn.commit()
def run_sqlite_migrations(db: Connection):
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
with open(schema_path) as schema_file:
schema = schema_file.read()
dumb_migrate_db(db=db, schema=schema)
def mark_in_progress_and_queued_transcriptions_as_canceled(conn: Connection):
cursor = conn.cursor()
cursor.execute(
"""
UPDATE transcription
SET status = 'canceled', time_ended = ?
WHERE status = 'in_progress' OR status = 'queued';
""",
(datetime.now().isoformat(),),
)
conn.commit()

284
buzz/db/migrator.py Normal file
View file

@ -0,0 +1,284 @@
# coding: utf-8
# https://gist.github.com/simonw/664b4b0851c1899dc55e1fb655181037
"""Simple declarative schema migration for SQLite.
See <https://david.rothlis.net/declarative-schema-migration-for-sqlite>.
Author: William Manley <will@stb-tester.com>.
Copyright © 2019-2022 Stb-tester.com Ltd.
License: MIT.
"""
import logging
import re
import sqlite3
from textwrap import dedent
def dumb_migrate_db(db, schema, allow_deletions=False):
"""
Migrates a database to the new schema given by the SQL text `schema`
preserving the data. We create any table that exists in schema, delete any
old table that is no longer used and add/remove columns and indices as
necessary.
Under this scheme there are a set of changes that we can make to the schema
and this script will handle it fine:
1. Adding a new table
2. Adding, deleting or modifying an index
3. Adding a column to an existing table as long as the new column can be
NULL or has a DEFAULT value specified.
4. Changing a column to remove NULL or DEFAULT as long as all values in the
database are not NULL
5. Changing the type of a column
6. Changing the user_version
In addition this function is capable of:
1. Deleting tables
2. Deleting columns from tables
But only if allow_deletions=True. If the new schema requires a column/table
to be deleted and allow_deletions=False this function will raise
`RuntimeError`.
Note: When this function is called a transaction must not be held open on
db. A transaction will be used internally. If you wish to perform
additional migration steps as part of a migration use DBMigrator directly.
Any internally generated rowid columns by SQLite may change values by this
migration.
"""
with DBMigrator(db, schema, allow_deletions) as migrator:
migrator.migrate()
return bool(migrator.n_changes)
class DBMigrator:
def __init__(self, db, schema, allow_deletions=False):
self.db = db
self.schema = schema
self.allow_deletions = allow_deletions
self.pristine = sqlite3.connect(":memory:")
self.pristine.executescript(schema)
self.n_changes = 0
self.orig_foreign_keys = None
def log_execute(self, msg, sql, args=None):
# It's important to log any changes we're making to the database for
# forensics later
msg_tmpl = "Database migration: %s with SQL:\n%s"
msg_argv = (msg, _left_pad(dedent(sql)))
if args:
msg_tmpl += " args = %r"
msg_argv += (args,)
else:
args = []
logging.info(msg_tmpl, *msg_argv)
self.db.execute(sql, args)
self.n_changes += 1
def __enter__(self):
self.orig_foreign_keys = self.db.execute("PRAGMA foreign_keys").fetchone()[0]
if self.orig_foreign_keys:
self.log_execute(
"Disable foreign keys temporarily for migration",
"PRAGMA foreign_keys = OFF",
)
# This doesn't count as a change because we'll undo it at the end
self.n_changes = 0
self.db.__enter__()
self.db.execute("BEGIN")
return self
def __exit__(self, exc_type, exc_value, exc_tb):
self.db.__exit__(exc_type, exc_value, exc_tb)
if exc_value is None:
# The SQLite docs say:
#
# > This pragma is a no-op within a transaction; foreign key
# > constraint enforcement may only be enabled or disabled when
# > there is no pending BEGIN or SAVEPOINT.
old_changes = self.n_changes
new_val = self._migrate_pragma("foreign_keys")
if new_val == self.orig_foreign_keys:
self.n_changes = old_changes
# SQLite docs say:
#
# > A VACUUM will fail if there is an open transaction on the database
# > connection that is attempting to run the VACUUM.
if self.n_changes:
self.db.execute("VACUUM")
else:
if self.orig_foreign_keys:
self.log_execute(
"Re-enable foreign keys after migration", "PRAGMA foreign_keys = ON"
)
def migrate(self):
# In CI the database schema may be changing all the time. This checks
# the current db and if it doesn't match database.sql we will
# modify it so it does match where possible.
pristine_tables = dict(
self.pristine.execute(
"""\
SELECT name, sql FROM sqlite_master
WHERE type = \"table\" AND name != \"sqlite_sequence\""""
).fetchall()
)
pristine_indices = dict(
self.pristine.execute(
"""\
SELECT name, sql FROM sqlite_master
WHERE type = \"index\""""
).fetchall()
)
tables = dict(
self.db.execute(
"""\
SELECT name, sql FROM sqlite_master
WHERE type = \"table\" AND name != \"sqlite_sequence\""""
).fetchall()
)
new_tables = set(pristine_tables.keys()) - set(tables.keys())
removed_tables = set(tables.keys()) - set(pristine_tables.keys())
if removed_tables and not self.allow_deletions:
raise RuntimeError(
"Database migration: Refusing to delete tables %r" % removed_tables
)
modified_tables = set(
name
for name, sql in pristine_tables.items()
if normalise_sql(tables.get(name, "")) != normalise_sql(sql)
)
# This PRAGMA is automatically disabled when the db is committed
self.db.execute("PRAGMA defer_foreign_keys = TRUE")
# New and removed tables are easy:
for tbl_name in new_tables:
self.log_execute("Create table %s" % tbl_name, pristine_tables[tbl_name])
for tbl_name in removed_tables:
self.log_execute("Drop table %s" % tbl_name, "DROP TABLE %s" % tbl_name)
for tbl_name in modified_tables:
# The SQLite documentation insists that we create the new table and
# rename it over the old rather than moving the old out of the way
# and then creating the new
create_table_sql = pristine_tables[tbl_name]
create_table_sql = re.sub(
r"\b%s\b" % re.escape(tbl_name),
tbl_name + "_migration_new",
create_table_sql,
)
self.log_execute(
"Columns change: Create table %s with updated schema" % tbl_name,
create_table_sql,
)
cols = set(
[x[1] for x in self.db.execute("PRAGMA table_info(%s)" % tbl_name)]
)
pristine_cols = set(
[
x[1]
for x in self.pristine.execute("PRAGMA table_info(%s)" % tbl_name)
]
)
removed_columns = cols - pristine_cols
if not self.allow_deletions and removed_columns:
logging.warning(
"Database migration: Refusing to remove columns %r from "
"table %s. Current cols are %r attempting migration to %r",
removed_columns,
tbl_name,
cols,
pristine_cols,
)
raise RuntimeError(
"Database migration: Refusing to remove columns %r from "
"table %s" % (removed_columns, tbl_name)
)
logging.info("cols: %s, pristine_cols: %s", cols, pristine_cols)
self.log_execute(
"Migrate data for table %s" % tbl_name,
"""\
INSERT INTO {tbl_name}_migration_new ({common})
SELECT {common} FROM {tbl_name}""".format(
tbl_name=tbl_name,
common=", ".join(cols.intersection(pristine_cols)),
),
)
# Don't need the old table any more
self.log_execute(
"Drop old table %s now data has been migrated" % tbl_name,
"DROP TABLE %s" % tbl_name,
)
self.log_execute(
"Columns change: Move new table %s over old" % tbl_name,
"ALTER TABLE %s_migration_new RENAME TO %s" % (tbl_name, tbl_name),
)
# Migrate the indices
indices = dict(
self.db.execute(
"""\
SELECT name, sql FROM sqlite_master
WHERE type = \"index\""""
).fetchall()
)
for name in set(indices.keys()) - set(pristine_indices.keys()):
self.log_execute(
"Dropping obsolete index %s" % name, "DROP INDEX %s" % name
)
for name, sql in pristine_indices.items():
if name not in indices:
self.log_execute("Creating new index %s" % name, sql)
elif sql != indices[name]:
self.log_execute(
"Index %s changed: Dropping old version" % name,
"DROP INDEX %s" % name,
)
self.log_execute(
"Index %s changed: Creating updated version in its place" % name,
sql,
)
self._migrate_pragma("user_version")
if self.pristine.execute("PRAGMA foreign_keys").fetchone()[0]:
if self.db.execute("PRAGMA foreign_key_check").fetchall():
raise RuntimeError("Database migration: Would fail foreign_key_check")
def _migrate_pragma(self, pragma):
pristine_val = self.pristine.execute("PRAGMA %s" % pragma).fetchone()[0]
val = self.db.execute("PRAGMA %s" % pragma).fetchone()[0]
if val != pristine_val:
self.log_execute(
"Set %s to %i from %i" % (pragma, pristine_val, val),
"PRAGMA %s = %i" % (pragma, pristine_val),
)
return pristine_val
def _left_pad(text, indent=" "):
"""Maybe I can find a package in pypi for this?"""
return "\n".join(indent + line for line in text.split("\n"))
def normalise_sql(sql):
# Remove comments:
sql = re.sub(r"--[^\n]*\n", "", sql)
# Normalise whitespace:
sql = re.sub(r"\s+", " ", sql)
sql = re.sub(r" *([(),]) *", r"\1", sql)
# Remove unnecessary quotes
sql = re.sub(r'"(\w+)"', r"\1", sql)
return sql.strip()

28
buzz/db/schema.sql Normal file
View file

@ -0,0 +1,28 @@
CREATE TABLE transcription (
id TEXT PRIMARY KEY,
error_message TEXT,
export_formats TEXT,
file TEXT,
output_folder TEXT,
progress DOUBLE PRECISION DEFAULT 0.0,
language TEXT,
model_type TEXT,
source TEXT,
status TEXT,
task TEXT,
time_ended TIMESTAMP,
time_queued TIMESTAMP NOT NULL,
time_started TIMESTAMP,
url TEXT,
whisper_model_size TEXT
);
CREATE TABLE transcription_segment (
id INTEGER PRIMARY KEY,
end_time INT DEFAULT 0,
start_time INT DEFAULT 0,
text TEXT NOT NULL,
transcription_id TEXT,
FOREIGN KEY (transcription_id) REFERENCES transcription(id) ON DELETE CASCADE
);
CREATE INDEX idx_transcription_id ON transcription_segment(transcription_id);

View file

View file

@ -0,0 +1,44 @@
from typing import List
from uuid import UUID
from buzz.db.dao.transcription_dao import TranscriptionDAO
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
from buzz.db.entity.transcription_segment import TranscriptionSegment
from buzz.transcriber.transcriber import Segment
class TranscriptionService:
def __init__(
self,
transcription_dao: TranscriptionDAO,
transcription_segment_dao: TranscriptionSegmentDAO,
):
self.transcription_dao = transcription_dao
self.transcription_segment_dao = transcription_segment_dao
def create_transcription(self, task):
self.transcription_dao.create_transcription(task)
def update_transcription_as_started(self, id: UUID):
self.transcription_dao.update_transcription_as_started(id)
def update_transcription_as_failed(self, id: UUID, error: str):
self.transcription_dao.update_transcription_as_failed(id, error)
def update_transcription_as_canceled(self, id: UUID):
self.transcription_dao.update_transcription_as_canceled(id)
def update_transcription_progress(self, id: UUID, progress: float):
self.transcription_dao.update_transcription_progress(id, progress)
def update_transcription_as_completed(self, id: UUID, segments: List[Segment]):
self.transcription_dao.update_transcription_as_completed(id)
for segment in segments:
self.transcription_segment_dao.insert(
TranscriptionSegment(
start_time=segment.start,
end_time=segment.end,
text=segment.text,
transcription_id=str(id),
)
)

View file

@ -1,8 +1,8 @@
import datetime
import logging
import multiprocessing
import queue
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Set
from uuid import UUID
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
@ -21,13 +21,19 @@ class FileTranscriberQueueWorker(QObject):
current_task: Optional[FileTranscriptionTask] = None
current_transcriber: Optional[FileTranscriber] = None
current_transcriber_thread: Optional[QThread] = None
task_updated = pyqtSignal(FileTranscriptionTask)
task_started = pyqtSignal(FileTranscriptionTask)
task_progress = pyqtSignal(FileTranscriptionTask, float)
task_download_progress = pyqtSignal(FileTranscriptionTask, float)
task_completed = pyqtSignal(FileTranscriptionTask, list)
task_error = pyqtSignal(FileTranscriptionTask, str)
completed = pyqtSignal()
def __init__(self, parent: Optional[QObject] = None):
super().__init__(parent)
self.tasks_queue = queue.Queue()
self.canceled_tasks = set()
self.canceled_tasks: Set[UUID] = set()
@pyqtSlot()
def run(self):
@ -42,7 +48,7 @@ class FileTranscriberQueueWorker(QObject):
self.completed.emit()
return
if self.current_task.id in self.canceled_tasks:
if self.current_task.uid in self.canceled_tasks:
continue
break
@ -91,53 +97,41 @@ class FileTranscriberQueueWorker(QObject):
self.current_transcriber.error.connect(self.run)
self.current_transcriber.completed.connect(self.run)
self.current_task.started_at = datetime.datetime.now()
self.task_started.emit(self.current_task)
self.current_transcriber_thread.start()
def add_task(self, task: FileTranscriptionTask):
if task.queued_at is None:
task.queued_at = datetime.datetime.now()
self.tasks_queue.put(task)
task.status = FileTranscriptionTask.Status.QUEUED
self.task_updated.emit(task)
def cancel_task(self, task_id: int):
def cancel_task(self, task_id: UUID):
self.canceled_tasks.add(task_id)
if self.current_task.id == task_id:
if self.current_task.uid == task_id:
if self.current_transcriber is not None:
self.current_transcriber.stop()
def on_task_error(self, error: str):
if (
self.current_task is not None
and self.current_task.id not in self.canceled_tasks
and self.current_task.uid not in self.canceled_tasks
):
self.current_task.status = FileTranscriptionTask.Status.FAILED
self.current_task.error = error
self.task_updated.emit(self.current_task)
self.task_error.emit(self.current_task, error)
@pyqtSlot(tuple)
def on_task_progress(self, progress: Tuple[int, int]):
if self.current_task is not None:
self.current_task.status = FileTranscriptionTask.Status.IN_PROGRESS
self.current_task.fraction_completed = progress[0] / progress[1]
self.task_updated.emit(self.current_task)
self.task_progress.emit(self.current_task, progress[0] / progress[1])
def on_task_download_progress(self, fraction_downloaded: float):
if self.current_task is not None:
self.current_task.status = FileTranscriptionTask.Status.IN_PROGRESS
self.current_task.fraction_downloaded = fraction_downloaded
self.task_updated.emit(self.current_task)
self.task_download_progress.emit(self.current_task, fraction_downloaded)
@pyqtSlot(list)
def on_task_completed(self, segments: List[Segment]):
if self.current_task is not None:
self.current_task.status = FileTranscriptionTask.Status.COMPLETED
self.current_task.segments = segments
self.current_task.completed_at = datetime.datetime.now()
self.task_updated.emit(self.current_task)
self.task_completed.emit(self.current_task, segments)
def stop(self):
self.tasks_queue.put(None)

View file

@ -68,7 +68,12 @@ class FileTranscriber(QObject):
output_format
) in self.transcription_task.file_transcription_options.output_formats:
default_path = get_output_file_path(
task=self.transcription_task, output_format=output_format
file_path=self.transcription_task.file_path,
output_format=output_format,
language=self.transcription_task.transcription_options.language,
output_directory=self.transcription_task.output_directory,
model=self.transcription_task.transcription_options.model,
task=self.transcription_task.transcription_options.task,
)
write_output(

View file

@ -1,6 +1,7 @@
import datetime
import enum
import os
import uuid
from dataclasses import dataclass, field
from random import randint
from typing import List, Optional, Tuple, Set
@ -174,7 +175,9 @@ class FileTranscriptionTask:
transcription_options: TranscriptionOptions
file_transcription_options: FileTranscriptionOptions
model_path: str
# deprecated: use uid
id: int = field(default_factory=lambda: randint(0, 100_000_000))
uid: uuid.UUID = field(default_factory=uuid.uuid4)
segments: List[Segment] = field(default_factory=list)
status: Optional[Status] = None
fraction_completed = 0.0
@ -188,38 +191,6 @@ class FileTranscriptionTask:
url: Optional[str] = None
fraction_downloaded: float = 0.0
def status_text(self) -> str:
match self.status:
case FileTranscriptionTask.Status.IN_PROGRESS:
if self.fraction_downloaded > 0 and self.fraction_completed == 0:
return f'{_("Downloading")} ({self.fraction_downloaded :.0%})'
return f'{_("In Progress")} ({self.fraction_completed :.0%})'
case FileTranscriptionTask.Status.COMPLETED:
status = _("Completed")
if self.started_at is not None and self.completed_at is not None:
status += f" ({self.format_timedelta(self.completed_at - self.started_at)})"
return status
case FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({self.error})'
case FileTranscriptionTask.Status.CANCELED:
return _("Canceled")
case FileTranscriptionTask.Status.QUEUED:
return _("Queued")
case _:
return ""
@staticmethod
def format_timedelta(delta: datetime.timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f"{mm}m {result}"
if hh == 0:
return result
return f"{hh}h {result}"
class OutputFormat(enum.Enum):
TXT = "txt"
@ -236,11 +207,15 @@ Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)"
def get_output_file_path(
task: FileTranscriptionTask,
file_path: str,
task: Task,
language: Optional[str],
model: TranscriptionModel,
output_format: OutputFormat,
output_directory: str | None = None,
export_file_name_template: str | None = None,
):
input_file_name = os.path.splitext(os.path.basename(task.file_path))[0]
input_file_name = os.path.splitext(os.path.basename(file_path))[0]
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
export_file_name_template = (
@ -251,18 +226,18 @@ def get_output_file_path(
output_file_name = (
export_file_name_template.replace("{{ input_file_name }}", input_file_name)
.replace("{{ task }}", task.transcription_options.task.value)
.replace("{{ language }}", task.transcription_options.language or "")
.replace("{{ model_type }}", task.transcription_options.model.model_type.value)
.replace("{{ task }}", task.value)
.replace("{{ language }}", language or "")
.replace("{{ model_type }}", model.model_type.value)
.replace(
"{{ model_size }}",
task.transcription_options.model.whisper_model_size.value
if task.transcription_options.model.whisper_model_size is not None
model.whisper_model_size.value
if model.whisper_model_size is not None
else "",
)
.replace("{{ date_time }}", date_time_now)
+ f".{output_format.value}"
)
output_directory = task.output_directory or os.path.dirname(task.file_path)
output_directory = output_directory or os.path.dirname(file_path)
return os.path.join(output_directory, output_file_name)

View file

@ -3,6 +3,7 @@ import sys
from PyQt6.QtWidgets import QApplication
from buzz.__version__ import VERSION
from buzz.db.service.transcription_service import TranscriptionService
from buzz.settings.settings import APP_NAME
from buzz.transcriber.transcriber import FileTranscriptionTask
from buzz.widgets.main_window import MainWindow
@ -11,7 +12,7 @@ from buzz.widgets.main_window import MainWindow
class Application(QApplication):
window: MainWindow
def __init__(self) -> None:
def __init__(self, transcription_service: TranscriptionService) -> None:
super().__init__(sys.argv)
self.setApplicationName(APP_NAME)
@ -20,7 +21,7 @@ class Application(QApplication):
if sys.platform == "darwin":
self.setStyle("Fusion")
self.window = MainWindow()
self.window = MainWindow(transcription_service)
self.window.show()
def add_task(self, task: FileTranscriptionTask):

View file

@ -1,4 +1,5 @@
from typing import Dict, Tuple, List, Optional
import logging
from typing import Tuple, List, Optional
from PyQt6 import QtGui
from PyQt6.QtCore import (
@ -13,7 +14,8 @@ from PyQt6.QtWidgets import (
QFileDialog,
)
from buzz.cache import TasksCache
from buzz.db.entity.transcription import Transcription
from buzz.db.service.transcription_service import TranscriptionService
from buzz.file_transcriber_queue_worker import FileTranscriberQueueWorker
from buzz.locale import _
from buzz.settings.settings import APP_NAME, Settings
@ -24,6 +26,7 @@ from buzz.transcriber.transcriber import (
TranscriptionOptions,
FileTranscriptionOptions,
SUPPORTED_AUDIO_FORMATS,
Segment,
)
from buzz.widgets.icon import BUZZ_ICON_PATH
from buzz.widgets.import_url_dialog import ImportURLDialog
@ -34,7 +37,9 @@ from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidg
from buzz.widgets.transcription_task_folder_watcher import (
TranscriptionTaskFolderWatcher,
)
from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
from buzz.widgets.transcription_tasks_table_widget import (
TranscriptionTasksTableWidget,
)
from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
TranscriptionViewerWidget,
)
@ -42,9 +47,8 @@ from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
class MainWindow(QMainWindow):
table_widget: TranscriptionTasksTableWidget
tasks: Dict[int, "FileTranscriptionTask"]
def __init__(self, tasks_cache=TasksCache()):
def __init__(self, transcription_service: TranscriptionService):
super().__init__(flags=Qt.WindowType.Window)
self.setWindowTitle(APP_NAME)
@ -53,14 +57,12 @@ class MainWindow(QMainWindow):
self.setAcceptDrops(True)
self.tasks_cache = tasks_cache
self.settings = Settings()
self.shortcut_settings = ShortcutSettings(settings=self.settings)
self.shortcuts = self.shortcut_settings.load()
self.tasks = {}
self.transcription_service = transcription_service
self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self)
self.toolbar.new_transcription_action_triggered.connect(
@ -100,7 +102,9 @@ class MainWindow(QMainWindow):
self.table_widget = TranscriptionTasksTableWidget(self)
self.table_widget.doubleClicked.connect(self.on_table_double_clicked)
self.table_widget.return_clicked.connect(self.open_transcript_viewer)
self.table_widget.itemSelectionChanged.connect(self.on_table_selection_changed)
self.table_widget.selectionModel().selectionChanged.connect(
self.on_table_selection_changed
)
self.setCentralWidget(self.table_widget)
@ -110,19 +114,24 @@ class MainWindow(QMainWindow):
self.transcriber_worker = FileTranscriberQueueWorker()
self.transcriber_worker.moveToThread(self.transcriber_thread)
self.transcriber_worker.task_updated.connect(self.update_task_table_row)
self.transcriber_worker.task_started.connect(self.on_task_started)
self.transcriber_worker.task_progress.connect(self.on_task_progress)
self.transcriber_worker.task_download_progress.connect(
self.on_task_download_progress
)
self.transcriber_worker.task_error.connect(self.on_task_error)
self.transcriber_worker.task_completed.connect(self.on_task_completed)
self.transcriber_worker.completed.connect(self.transcriber_thread.quit)
self.transcriber_thread.started.connect(self.transcriber_worker.run)
self.transcriber_thread.start()
self.load_tasks_from_cache()
self.load_geometry()
self.folder_watcher = TranscriptionTaskFolderWatcher(
tasks=self.tasks,
tasks={},
preferences=self.preferences.folder_watch,
)
self.folder_watcher.task_found.connect(self.add_task)
@ -181,21 +190,6 @@ class MainWindow(QMainWindow):
)
self.add_task(task)
def upsert_task_in_table(self, task: FileTranscriptionTask):
self.table_widget.upsert_task(task)
self.tasks[task.id] = task
def update_task_table_row(self, task: FileTranscriptionTask):
self.upsert_task_in_table(task=task)
self.on_tasks_changed()
@staticmethod
def task_completed_or_errored(task: FileTranscriptionTask):
return (
task.status == FileTranscriptionTask.Status.COMPLETED
or task.status == FileTranscriptionTask.Status.FAILED
)
def on_clear_history_action_triggered(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
if len(selected_rows) == 0:
@ -210,25 +204,18 @@ class MainWindow(QMainWindow):
),
)
if reply == QMessageBox.StandardButton.Yes:
task_ids = [
TranscriptionTasksTableWidget.find_task_id(selected_row)
for selected_row in selected_rows
]
for task_id in task_ids:
self.table_widget.clear_task(task_id)
self.tasks.pop(task_id)
self.on_tasks_changed()
self.table_widget.delete_transcriptions(selected_rows)
def on_stop_transcription_action_triggered(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
for selected_row in selected_rows:
task_id = TranscriptionTasksTableWidget.find_task_id(selected_row)
task = self.tasks[task_id]
task.status = FileTranscriptionTask.Status.CANCELED
self.on_tasks_changed()
self.transcriber_worker.cancel_task(task_id)
self.table_widget.upsert_task(task)
selected_transcriptions = self.table_widget.selected_transcriptions()
for transcription in selected_transcriptions:
transcription_id = transcription.id_as_uuid
self.transcriber_worker.cancel_task(transcription_id)
self.transcription_service.update_transcription_as_canceled(
transcription_id
)
self.table_widget.refresh_row(transcription_id)
self.on_table_selection_changed()
def on_new_transcription_action_triggered(self):
(file_paths, __) = QFileDialog.getOpenFileNames(
@ -266,8 +253,8 @@ class MainWindow(QMainWindow):
def open_transcript_viewer(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
for selected_row in selected_rows:
task_id = TranscriptionTasksTableWidget.find_task_id(selected_row)
self.open_transcription_viewer(task_id)
transcription = self.table_widget.transcription(selected_row)
self.open_transcription_viewer(transcription)
def on_table_selection_changed(self):
self.toolbar.set_open_transcript_action_enabled(
@ -281,7 +268,20 @@ class MainWindow(QMainWindow):
)
def should_enable_open_transcript_action(self):
return self.selected_tasks_have_status([FileTranscriptionTask.Status.COMPLETED])
selected_transcriptions = self.table_widget.selected_transcriptions()
if len(selected_transcriptions) == 0:
return False
return all(
MainWindow.can_open_transcript(transcription)
for transcription in selected_transcriptions
)
@staticmethod
def can_open_transcript(transcription: Transcription) -> bool:
return (
FileTranscriptionTask.Status(transcription.status)
== FileTranscriptionTask.Status.COMPLETED
)
def should_enable_stop_transcription_action(self):
return self.selected_tasks_have_status(
@ -301,63 +301,56 @@ class MainWindow(QMainWindow):
)
def selected_tasks_have_status(self, statuses: List[FileTranscriptionTask.Status]):
selected_rows = self.table_widget.selectionModel().selectedRows()
if len(selected_rows) == 0:
transcriptions = self.table_widget.selected_transcriptions()
if len(transcriptions) == 0:
return False
return all(
[
self.tasks[
TranscriptionTasksTableWidget.find_task_id(selected_row)
].status
in statuses
for selected_row in selected_rows
transcription.status_as_status in statuses
for transcription in transcriptions
]
)
def on_table_double_clicked(self, index: QModelIndex):
task_id = TranscriptionTasksTableWidget.find_task_id(index)
self.open_transcription_viewer(task_id)
def open_transcription_viewer(self, task_id: int):
task = self.tasks[task_id]
if task.status != FileTranscriptionTask.Status.COMPLETED:
transcription = self.table_widget.transcription(index)
if not MainWindow.can_open_transcript(transcription):
return
self.open_transcription_viewer(transcription)
def open_transcription_viewer(self, transcription: Transcription):
transcription_viewer_widget = TranscriptionViewerWidget(
transcription_task=task, parent=self, flags=Qt.WindowType.Window
transcription=transcription, parent=self, flags=Qt.WindowType.Window
)
transcription_viewer_widget.task_changed.connect(self.on_tasks_changed)
transcription_viewer_widget.show()
def add_task(self, task: FileTranscriptionTask):
self.transcription_service.create_transcription(task)
self.table_widget.refresh_all()
self.transcriber_worker.add_task(task)
def load_tasks_from_cache(self):
tasks = self.tasks_cache.load()
for task in tasks:
if (
task.status == FileTranscriptionTask.Status.QUEUED
or task.status == FileTranscriptionTask.Status.IN_PROGRESS
):
task.status = None
self.add_task(task)
else:
self.upsert_task_in_table(task=task)
def on_task_started(self, task: FileTranscriptionTask):
self.transcription_service.update_transcription_as_started(task.uid)
self.table_widget.refresh_row(task.uid)
def save_tasks_to_cache(self):
self.tasks_cache.save(list(self.tasks.values()))
def on_task_progress(self, task: FileTranscriptionTask, progress: float):
self.transcription_service.update_transcription_progress(task.uid, progress)
self.table_widget.refresh_row(task.uid)
def on_tasks_changed(self):
self.toolbar.set_open_transcript_action_enabled(
self.should_enable_open_transcript_action()
)
self.toolbar.set_stop_transcription_action_enabled(
self.should_enable_stop_transcription_action()
)
self.toolbar.set_clear_history_action_enabled(
self.should_enable_clear_history_action()
)
self.save_tasks_to_cache()
def on_task_download_progress(
self, task: FileTranscriptionTask, fraction_downloaded: float
):
# TODO: Save download progress in the database
pass
def on_task_completed(self, task: FileTranscriptionTask, segments: List[Segment]):
self.transcription_service.update_transcription_as_completed(task.uid, segments)
self.table_widget.refresh_row(task.uid)
def on_task_error(self, task: FileTranscriptionTask, error: str):
logging.debug("FAILED!!!!")
self.transcription_service.update_transcription_as_failed(task.uid, error)
self.table_widget.refresh_row(task.uid)
def on_shortcuts_changed(self, shortcuts: dict):
self.shortcuts = shortcuts
@ -374,7 +367,6 @@ class MainWindow(QMainWindow):
self.transcriber_worker.stop()
self.transcriber_thread.quit()
self.transcriber_thread.wait()
self.save_tasks_to_cache()
self.shortcut_settings.save(shortcuts=self.shortcuts)
super().closeEvent(event)

View file

@ -0,0 +1,15 @@
from typing import Callable
from PyQt6.QtSql import QSqlRecord, QSqlTableModel
from PyQt6.QtWidgets import QStyledItemDelegate
class RecordDelegate(QStyledItemDelegate):
def __init__(self, text_getter: Callable[[QSqlRecord], str]):
super().__init__()
self.callback = text_getter
def initStyleOption(self, option, index):
super().initStyleOption(option, index)
model: QSqlTableModel = index.model()
option.text = self.callback(model.record(index.row()))

View file

@ -0,0 +1,25 @@
from uuid import UUID
from PyQt6.QtSql import QSqlRecord
from buzz.model_loader import TranscriptionModel, ModelType, WhisperModelSize
from buzz.transcriber.transcriber import Task
class TranscriptionRecord:
@staticmethod
def id(record: QSqlRecord) -> UUID:
return UUID(hex=record.value("id"))
@staticmethod
def model(record: QSqlRecord) -> TranscriptionModel:
return TranscriptionModel(
model_type=ModelType(record.value("model_type")),
whisper_model_size=WhisperModelSize(record.value("whisper_model_size"))
if record.value("whisper_model_size")
else None,
)
@staticmethod
def task(record: QSqlRecord) -> Task:
return Task(record.value("task"))

View file

@ -15,6 +15,7 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
preferences: FolderWatchPreferences
task_found = pyqtSignal(FileTranscriptionTask)
# TODO: query db instead of passing tasks
def __init__(
self,
tasks: Dict[int, FileTranscriptionTask],

View file

@ -1,147 +1,203 @@
import enum
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import auto
from typing import Optional, Callable
from typing import Optional, List
from uuid import UUID
from PyQt6 import QtGui
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex
from PyQt6.QtCore import Qt
from PyQt6.QtCore import pyqtSignal, QModelIndex
from PyQt6.QtSql import QSqlTableModel, QSqlRecord
from PyQt6.QtWidgets import (
QTableWidget,
QWidget,
QAbstractItemView,
QTableWidgetItem,
QMenu,
QTableView,
QAbstractItemView,
QStyledItemDelegate,
)
from buzz.db.entity.transcription import Transcription
from buzz.locale import _
from buzz.settings.settings import Settings
from buzz.transcriber.transcriber import FileTranscriptionTask, humanize_language
from buzz.transcriber.transcriber import FileTranscriptionTask
from buzz.widgets.record_delegate import RecordDelegate
from buzz.widgets.transcription_record import TranscriptionRecord
class Column(enum.Enum):
ID = 0
ERROR_MESSAGE = auto()
EXPORT_FORMATS = auto()
FILE = auto()
OUTPUT_FOLDER = auto()
PROGRESS = auto()
LANGUAGE = auto()
MODEL_TYPE = auto()
SOURCE = auto()
STATUS = auto()
TASK = auto()
TIME_ENDED = auto()
TIME_QUEUED = auto()
TIME_STARTED = auto()
URL = auto()
WHISPER_MODEL_SIZE = auto()
@dataclass
class TableColDef:
class ColDef:
id: str
header: str
column_index: int
value_getter: Callable[[FileTranscriptionTask], str]
column: Column
width: Optional[int] = None
hidden: bool = False
delegate: Optional[QStyledItemDelegate] = None
hidden_toggleable: bool = True
class TranscriptionTasksTableWidget(QTableWidget):
class Column(enum.Enum):
TASK_ID = 0
FILE_NAME = auto()
MODEL = auto()
TASK = auto()
STATUS = auto()
DATE_ADDED = auto()
DATE_COMPLETED = auto()
def format_record_status_text(record: QSqlRecord) -> str:
status = FileTranscriptionTask.Status(record.value("status"))
match status:
case FileTranscriptionTask.Status.IN_PROGRESS:
return f'{_("In Progress")} ({record.value("progress") :.0%})'
case FileTranscriptionTask.Status.COMPLETED:
status = _("Completed")
started_at = record.value("time_started")
completed_at = record.value("time_ended")
if started_at != "" and completed_at != "":
status += f" ({TranscriptionTasksTableWidget.format_timedelta(datetime.fromisoformat(completed_at) - datetime.fromisoformat(started_at))})"
return status
case FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({record.value("error_message")})'
case FileTranscriptionTask.Status.CANCELED:
return _("Canceled")
case FileTranscriptionTask.Status.QUEUED:
return _("Queued")
case _:
return ""
column_definitions = [
ColDef(
id="file_name",
header="File Name / URL",
column=Column.FILE,
width=400,
delegate=RecordDelegate(
text_getter=lambda record: record.value("url")
if record.value("url") != ""
else os.path.basename(record.value("file"))
),
hidden_toggleable=False,
),
ColDef(
id="model",
header="Model",
column=Column.MODEL_TYPE,
width=180,
delegate=RecordDelegate(
text_getter=lambda record: str(TranscriptionRecord.model(record))
),
),
ColDef(
id="task",
header="Task",
column=Column.SOURCE,
width=120,
delegate=RecordDelegate(
text_getter=lambda record: record.value("task").capitalize()
),
),
ColDef(
id="status",
header="Status",
column=Column.STATUS,
width=180,
delegate=RecordDelegate(text_getter=format_record_status_text),
hidden_toggleable=False,
),
ColDef(
id="date_added",
header="Date Added",
column=Column.TIME_QUEUED,
width=180,
delegate=RecordDelegate(
text_getter=lambda record: datetime.fromisoformat(
record.value("time_queued")
).strftime("%Y-%m-%d %H:%M:%S")
),
),
ColDef(
id="date_completed",
header="Date Completed",
column=Column.TIME_ENDED,
width=180,
delegate=RecordDelegate(
text_getter=lambda record: datetime.fromisoformat(
record.value("time_ended")
).strftime("%Y-%m-%d %H:%M:%S")
if record.value("time_ended") != ""
else ""
),
),
]
class TranscriptionTasksTableWidget(QTableView):
return_clicked = pyqtSignal()
def __init__(self, parent: Optional[QWidget] = None):
super().__init__(parent)
self.setRowCount(0)
self.setAlternatingRowColors(True)
self._model = QSqlTableModel()
self._model.setTable("transcription")
self._model.setEditStrategy(QSqlTableModel.EditStrategy.OnManualSubmit)
self._model.setSort(Column.TIME_QUEUED.value, Qt.SortOrder.DescendingOrder)
self.setModel(self._model)
for i in range(self.model().columnCount()):
self.hideColumn(i)
self.settings = Settings()
self.column_definitions = [
TableColDef(
id="id",
header=_("ID"),
column_index=self.Column.TASK_ID.value,
value_getter=lambda task: str(task.id),
width=0,
hidden=True,
hidden_toggleable=False,
),
TableColDef(
id="file_name",
header=_("File Name/URL"),
column_index=self.Column.FILE_NAME.value,
value_getter=lambda task: task.url
if task.url is not None
else os.path.basename(task.file_path),
width=300,
hidden_toggleable=False,
),
TableColDef(
id="model",
header=_("Model"),
column_index=self.Column.MODEL.value,
value_getter=lambda task: str(task.transcription_options.model),
width=180,
hidden=True,
),
TableColDef(
id="task",
header=_("Task"),
column_index=self.Column.TASK.value,
value_getter=lambda task: self.get_task_label(task),
width=120,
hidden=True,
),
TableColDef(
id="status",
header=_("Status"),
column_index=self.Column.STATUS.value,
value_getter=lambda task: task.status_text(),
width=180,
hidden_toggleable=False,
),
TableColDef(
id="date_added",
header=_("Date Added"),
column_index=self.Column.DATE_ADDED.value,
value_getter=lambda task: task.queued_at.strftime("%Y-%m-%d %H:%M:%S")
if task.queued_at is not None
else "",
width=180,
hidden=True,
),
TableColDef(
id="date_completed",
header=_("Date Completed"),
column_index=self.Column.DATE_COMPLETED.value,
value_getter=lambda task: task.completed_at.strftime(
"%Y-%m-%d %H:%M:%S"
)
if task.completed_at is not None
else "",
width=180,
hidden=True,
),
]
self.setColumnCount(len(self.column_definitions))
self.verticalHeader().hide()
self.setHorizontalHeaderLabels(
[definition.header for definition in self.column_definitions]
self.settings.begin_group(
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
for definition in column_definitions:
self.model().setHeaderData(
definition.column.value,
Qt.Orientation.Horizontal,
definition.header,
)
visible = self.settings.settings.value(definition.id, True)
self.setColumnHidden(definition.column.value, not visible)
if definition.width is not None:
self.setColumnWidth(definition.column_index, definition.width)
self.load_column_visibility()
self.horizontalHeader().setMinimumSectionSize(180)
self.setColumnWidth(definition.column.value, definition.width)
if definition.delegate is not None:
self.setItemDelegateForColumn(
definition.column.value, definition.delegate
)
self.settings.end_group()
self.model().select()
self.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
self.verticalHeader().hide()
self.setAlternatingRowColors(True)
def contextMenuEvent(self, event):
menu = QMenu(self)
for definition in self.column_definitions:
for definition in column_definitions:
if not definition.hidden_toggleable:
continue
action = menu.addAction(definition.header)
action.setCheckable(True)
action.setChecked(not self.isColumnHidden(definition.column_index))
action.setChecked(not self.isColumnHidden(definition.column.value))
action.toggled.connect(
lambda checked,
column_index=definition.column_index: self.on_column_checked(
column_index=definition.column.value: self.on_column_checked(
column_index, checked
)
)
@ -155,66 +211,47 @@ class TranscriptionTasksTableWidget(QTableWidget):
self.settings.begin_group(
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
for definition in column_definitions:
self.settings.settings.setValue(
definition.id, not self.isColumnHidden(definition.column_index)
definition.id, not self.isColumnHidden(definition.column.value)
)
self.settings.end_group()
def load_column_visibility(self):
self.settings.begin_group(
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
visible = self.settings.settings.value(definition.id, not definition.hidden)
self.setColumnHidden(definition.column_index, not visible)
self.settings.end_group()
def upsert_task(self, task: FileTranscriptionTask):
task_row_index = self.task_row_index(task.id)
if task_row_index is None:
self.insertRow(self.rowCount())
row_index = self.rowCount() - 1
for definition in self.column_definitions:
item = QTableWidgetItem(definition.value_getter(task))
item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, definition.column_index, item)
else:
for definition in self.column_definitions:
item = self.item(task_row_index, definition.column_index)
item.setText(definition.value_getter(task))
@staticmethod
def get_task_label(task: FileTranscriptionTask) -> str:
value = task.transcription_options.task.value.capitalize()
if task.transcription_options.language is not None:
value += f" ({humanize_language(task.transcription_options.language)})"
return value
def clear_task(self, task_id: int):
task_row_index = self.task_row_index(task_id)
if task_row_index is not None:
self.removeRow(task_row_index)
def task_row_index(self, task_id: int) -> int | None:
table_items_matching_task_id = [
item
for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly)
if item.column() == self.Column.TASK_ID.value
]
if len(table_items_matching_task_id) == 0:
return None
return table_items_matching_task_id[0].row()
@staticmethod
def find_task_id(index: QModelIndex):
sibling_index = index.siblingAtColumn(
TranscriptionTasksTableWidget.Column.TASK_ID.value
).data()
return int(sibling_index) if sibling_index is not None else None
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
if event.key() == Qt.Key.Key_Return:
self.return_clicked.emit()
super().keyPressEvent(event)
def selected_transcriptions(self) -> List[Transcription]:
selected = self.selectionModel().selectedRows()
return [self.transcription(row) for row in selected]
def delete_transcriptions(self, rows: List[QModelIndex]):
for row in rows:
self.model().removeRow(row.row())
self.model().submitAll()
def transcription(self, index: QModelIndex) -> Transcription:
return Transcription.from_record(self.model().record(index.row()))
def refresh_all(self):
self.model().select()
def refresh_row(self, id: UUID):
for i in range(self.model().rowCount()):
record = self.model().record(i)
if record.value("id") == str(id):
self.model().selectRow(i)
return
@staticmethod
def format_timedelta(delta: timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f"{mm}m {result}"
if hh == 0:
return result
return f"{hh}h {result}"

View file

@ -1,20 +1,18 @@
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtGui import QAction
from PyQt6.QtWidgets import QPushButton, QWidget, QMenu, QFileDialog
from PyQt6.QtWidgets import QPushButton, QWidget, QMenu
from buzz.locale import _
from buzz.transcriber.file_transcriber import write_output
from buzz.transcriber.transcriber import (
FileTranscriptionTask,
OutputFormat,
get_output_file_path,
)
from buzz.widgets.icon import FileDownloadIcon
class ExportTranscriptionButton(QPushButton):
def __init__(self, transcription_task: FileTranscriptionTask, parent: QWidget):
on_export_triggered = pyqtSignal(OutputFormat)
def __init__(self, parent: QWidget):
super().__init__(parent)
self.transcription_task = transcription_task
export_button_menu = QMenu()
actions = [
@ -29,23 +27,4 @@ class ExportTranscriptionButton(QPushButton):
def on_menu_triggered(self, action: QAction):
output_format = OutputFormat[action.text()]
default_path = get_output_file_path(
task=self.transcription_task, output_format=output_format
)
(output_file_path, nil) = QFileDialog.getSaveFileName(
self,
_("Save File"),
default_path,
_("Text files") + f" (*.{output_format.value})",
)
if output_file_path == "":
return
write_output(
path=output_file_path,
segments=self.transcription_task.segments,
output_format=output_format,
)
self.on_export_triggered.emit(output_format)

View file

@ -1,72 +1,105 @@
import enum
from typing import List, Optional
from dataclasses import dataclass
from typing import Optional
from uuid import UUID
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtWidgets import QTableWidget, QWidget, QHeaderView, QTableWidgetItem
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex, QItemSelection
from PyQt6.QtSql import QSqlTableModel, QSqlRecord
from PyQt6.QtWidgets import (
QWidget,
QTableView,
QStyledItemDelegate,
QAbstractItemView,
)
from buzz.locale import _
from buzz.transcriber.file_transcriber import to_timestamp
from buzz.transcriber.transcriber import Segment
class TranscriptionSegmentsEditorWidget(QTableWidget):
segment_text_changed = pyqtSignal(tuple)
segment_index_selected = pyqtSignal(int)
class Column(enum.Enum):
ID = 0
END = enum.auto()
START = enum.auto()
TEXT = enum.auto()
TRANSCRIPTION_ID = enum.auto()
class Column(enum.Enum):
START = 0
END = enum.auto()
TEXT = enum.auto()
def __init__(self, segments: List[Segment], parent: Optional[QWidget]):
@dataclass
class ColDef:
id: str
header: str
column: Column
delegate: Optional[QStyledItemDelegate] = None
class TimeStampDelegate(QStyledItemDelegate):
def displayText(self, value, locale):
return to_timestamp(value)
class TranscriptionSegmentModel(QSqlTableModel):
def __init__(self, transcription_id: UUID):
super().__init__()
self.setTable("transcription_segment")
self.setEditStrategy(QSqlTableModel.EditStrategy.OnFieldChange)
self.setFilter(f"transcription_id = '{transcription_id}'")
def flags(self, index: QModelIndex):
flags = super().flags(index)
if index.column() in (Column.START.value, Column.END.value):
flags &= ~Qt.ItemFlag.ItemIsEditable
return flags
class TranscriptionSegmentsEditorWidget(QTableView):
segment_selected = pyqtSignal(QSqlRecord)
def __init__(self, transcription_id: UUID, parent: Optional[QWidget]):
super().__init__(parent)
self.segments = segments
model = TranscriptionSegmentModel(transcription_id=transcription_id)
self.setModel(model)
timestamp_delegate = TimeStampDelegate()
self.column_definitions: list[ColDef] = [
ColDef("start", _("Start"), Column.START, delegate=timestamp_delegate),
ColDef("end", _("End"), Column.END, delegate=timestamp_delegate),
ColDef("text", _("Text"), Column.TEXT),
]
for i in range(model.columnCount()):
self.hideColumn(i)
for definition in self.column_definitions:
model.setHeaderData(
definition.column.value,
Qt.Orientation.Horizontal,
definition.header,
)
self.showColumn(definition.column.value)
if definition.delegate is not None:
self.setItemDelegateForColumn(
definition.column.value, definition.delegate
)
self.setAlternatingRowColors(True)
self.setColumnCount(3)
self.verticalHeader().hide()
self.setHorizontalHeaderLabels([_("Start"), _("End"), _("Text")])
self.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeMode.ResizeToContents
)
self.setSelectionMode(QTableWidget.SelectionMode.SingleSelection)
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
self.setSelectionMode(QTableView.SelectionMode.SingleSelection)
self.selectionModel().selectionChanged.connect(self.on_selection_changed)
model.select()
for segment in segments:
row_index = self.rowCount()
self.insertRow(row_index)
self.resizeColumnsToContents()
start_item = QTableWidgetItem(to_timestamp(segment.start))
start_item.setFlags(
start_item.flags()
& ~Qt.ItemFlag.ItemIsEditable
& ~Qt.ItemFlag.ItemIsSelectable
)
self.setItem(row_index, self.Column.START.value, start_item)
def on_selection_changed(
self, selected: QItemSelection, _deselected: QItemSelection
):
if selected.indexes():
self.segment_selected.emit(self.segment(selected.indexes()[0]))
end_item = QTableWidgetItem(to_timestamp(segment.end))
end_item.setFlags(
end_item.flags()
& ~Qt.ItemFlag.ItemIsEditable
& ~Qt.ItemFlag.ItemIsSelectable
)
self.setItem(row_index, self.Column.END.value, end_item)
def segment(self, index: QModelIndex) -> QSqlRecord:
return self.model().record(index.row())
text_item = QTableWidgetItem(segment.text)
self.setItem(row_index, self.Column.TEXT.value, text_item)
self.itemChanged.connect(self.on_item_changed)
self.itemSelectionChanged.connect(self.on_item_selection_changed)
def on_item_changed(self, item: QTableWidgetItem):
if item.column() == self.Column.TEXT.value:
self.segment_text_changed.emit((item.row(), item.text()))
def set_segment_text(self, index: int, text: str):
self.item(index, self.Column.TEXT.value).setText(text)
def on_item_selection_changed(self):
ranges = self.selectedRanges()
self.segment_index_selected.emit(ranges[0].topRow() if len(ranges) > 0 else -1)
def segments(self) -> list[QSqlRecord]:
return [self.model().record(i) for i in range(self.model().rowCount())]

View file

@ -1,25 +1,24 @@
import platform
from typing import List, Optional
from typing import Optional
from uuid import UUID
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence
from PyQt6.QtCore import Qt
from PyQt6.QtMultimedia import QMediaPlayer
from PyQt6.QtSql import QSqlRecord
from PyQt6.QtWidgets import (
QWidget,
QHBoxLayout,
QLabel,
QGridLayout,
QFileDialog,
)
from buzz.action import Action
from buzz.db.entity.transcription import Transcription
from buzz.locale import _
from buzz.paths import file_path_as_title
from buzz.transcriber.transcriber import (
FileTranscriptionTask,
Segment,
)
from buzz.transcriber.file_transcriber import write_output
from buzz.transcriber.transcriber import OutputFormat, Segment
from buzz.widgets.audio_player import AudioPlayer
from buzz.widgets.icon import UndoIcon, RedoIcon
from buzz.widgets.toolbar import ToolBar
from buzz.widgets.transcription_viewer.export_transcription_button import (
ExportTranscriptionButton,
)
@ -28,84 +27,31 @@ from buzz.widgets.transcription_viewer.transcription_segments_editor_widget impo
)
class ChangeSegmentTextCommand(QUndoCommand):
def __init__(
self,
table_widget: TranscriptionSegmentsEditorWidget,
segments: List[Segment],
segment_index: int,
segment_text: str,
task_changed: pyqtSignal,
):
super().__init__()
self.table_widget = table_widget
self.segments = segments
self.segment_index = segment_index
self.segment_text = segment_text
self.task_changed = task_changed
self.previous_segment_text = self.segments[self.segment_index].text
def undo(self) -> None:
self.set_segment_text(self.previous_segment_text)
def redo(self) -> None:
self.set_segment_text(self.segment_text)
def set_segment_text(self, text: str):
# block signals before setting text so it doesn't re-trigger a new UndoCommand
self.table_widget.blockSignals(True)
self.table_widget.set_segment_text(self.segment_index, text)
self.table_widget.blockSignals(False)
self.segments[self.segment_index].text = text
self.task_changed.emit()
class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask
task_changed = pyqtSignal()
transcription: Transcription
def __init__(
self,
transcription_task: FileTranscriptionTask,
open_transcription_output=True,
transcription: Transcription,
parent: Optional["QWidget"] = None,
flags: Qt.WindowType = Qt.WindowType.Widget,
) -> None:
super().__init__(parent, flags)
self.transcription_task = transcription_task
self.open_transcription_output = open_transcription_output
self.transcription = transcription
self.setMinimumWidth(800)
self.setMinimumHeight(500)
self.setWindowTitle(file_path_as_title(transcription_task.file_path))
self.undo_stack = QUndoStack()
undo_action = self.undo_stack.createUndoAction(self, _("Undo"))
undo_action.setShortcuts(QKeySequence.StandardKey.Undo)
undo_action.setIcon(UndoIcon(parent=self))
undo_action.setToolTip(Action.get_tooltip(undo_action))
redo_action = self.undo_stack.createRedoAction(self, _("Redo"))
redo_action.setShortcuts(QKeySequence.StandardKey.Redo)
redo_action.setIcon(RedoIcon(parent=self))
redo_action.setToolTip(Action.get_tooltip(redo_action))
toolbar = ToolBar()
toolbar.addActions([undo_action, redo_action])
self.setWindowTitle(file_path_as_title(transcription.file))
self.table_widget = TranscriptionSegmentsEditorWidget(
segments=transcription_task.segments, parent=self
transcription_id=UUID(hex=transcription.id), parent=self
)
self.table_widget.segment_text_changed.connect(self.on_segment_text_changed)
self.table_widget.segment_index_selected.connect(self.on_segment_index_selected)
self.table_widget.segment_selected.connect(self.on_segment_selected)
self.audio_player: Optional[AudioPlayer] = None
if platform.system() != "Linux":
self.audio_player = AudioPlayer(file_path=transcription_task.file_path)
self.audio_player = AudioPlayer(file_path=transcription.file)
self.audio_player.position_ms_changed.connect(
self.on_audio_player_position_ms_changed
)
@ -118,12 +64,10 @@ class TranscriptionViewerWidget(QWidget):
buttons_layout = QHBoxLayout()
buttons_layout.addStretch()
export_button = ExportTranscriptionButton(
transcription_task=transcription_task, parent=self
)
export_button = ExportTranscriptionButton(parent=self)
export_button.on_export_triggered.connect(self.on_export_triggered)
layout = QGridLayout(self)
layout.setMenuBar(toolbar)
layout.addWidget(self.table_widget, 0, 0, 1, 2)
if self.audio_player is not None:
@ -133,33 +77,56 @@ class TranscriptionViewerWidget(QWidget):
self.setLayout(layout)
def on_segment_text_changed(self, event: tuple):
segment_index, segment_text = event
self.undo_stack.push(
ChangeSegmentTextCommand(
table_widget=self.table_widget,
segments=self.transcription_task.segments,
segment_index=segment_index,
segment_text=segment_text,
task_changed=self.task_changed,
)
def on_export_triggered(self, output_format: OutputFormat) -> None:
default_path = self.transcription.get_output_file_path(
output_format=output_format
)
def on_segment_index_selected(self, index: int):
selected_segment = self.transcription_task.segments[index]
if self.audio_player is not None:
self.audio_player.set_range((selected_segment.start, selected_segment.end))
(output_file_path, nil) = QFileDialog.getSaveFileName(
self,
_("Save File"),
default_path,
_("Text files") + f" (*.{output_format.value})",
)
if output_file_path == "":
return
segments = [
Segment(
start=segment.value("start_time"),
end=segment.value("end_time"),
text=segment.value("text"),
)
for segment in self.table_widget.segments()
]
write_output(
path=output_file_path,
segments=segments,
output_format=output_format,
)
def on_segment_selected(self, segment: QSqlRecord):
if self.audio_player is not None and (
self.audio_player.media_player.playbackState()
== QMediaPlayer.PlaybackState.PlayingState
):
self.audio_player.set_range(
(segment.value("start_time"), segment.value("end_time"))
)
def on_audio_player_position_ms_changed(self, position_ms: int) -> None:
current_segment_index: Optional[int] = next(
segments = self.table_widget.segments()
current_segment = next(
(
i
for i, segment in enumerate(self.transcription_task.segments)
if segment.start <= position_ms < segment.end
segment
for segment in segments
if segment.value("start_time")
<= position_ms
< segment.value("end_time")
),
None,
)
if current_segment_index is not None:
self.current_segment_label.setText(
self.transcription_task.segments[current_segment_index].text
)
if current_segment is not None:
self.current_segment_label.setText(current_segment.value("text"))

1123
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -14,7 +14,6 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.9.13,<3.11"
sounddevice = "^0.4.5"
appdirs = "^1.4.4"
humanize = "^4.4.0"
PyQt6 = "^6.4.0"
openai = "^1.6.1"

40
tests/conftest.py Normal file
View file

@ -0,0 +1,40 @@
import os
import pytest
from PyQt6.QtSql import QSqlDatabase
from _pytest.fixtures import SubRequest
from buzz.db.dao.transcription_dao import TranscriptionDAO
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
from buzz.db.db import setup_test_db
from buzz.db.service.transcription_service import TranscriptionService
@pytest.fixture()
def db() -> QSqlDatabase:
db = setup_test_db()
yield db
db.close()
os.remove(db.databaseName())
@pytest.fixture()
def transcription_dao(db, request: SubRequest) -> TranscriptionDAO:
dao = TranscriptionDAO(db)
if hasattr(request, "param"):
transcriptions = request.param
for transcription in transcriptions:
dao.insert(transcription)
return dao
@pytest.fixture()
def transcription_service(
transcription_dao, transcription_segment_dao
) -> TranscriptionService:
return TranscriptionService(transcription_dao, transcription_segment_dao)
@pytest.fixture()
def transcription_segment_dao(db) -> TranscriptionSegmentDAO:
return TranscriptionSegmentDAO(db)

View file

@ -54,13 +54,15 @@ class TestWhisperFileTranscriber:
expected_file_path: str,
):
file_path = get_output_file_path(
task=FileTranscriptionTask(
file_path=file_path,
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
file_path=file_path,
language=None,
task=Task.TRANSLATE,
model=TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
output_format=output_format,
output_directory="",
export_file_name_template="{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
)
assert file_path == expected_file_path
@ -87,15 +89,15 @@ class TestWhisperFileTranscriber:
"{{ input_file_name }} (Translated on {{ date_time }})"
)
srt = get_output_file_path(
task=FileTranscriptionTask(
file_path=file_path,
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(
file_paths=[],
),
model_path="",
file_path=file_path,
language=None,
task=Task.TRANSLATE,
model=TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
output_format=OutputFormat.TXT,
output_directory="",
export_file_name_template=export_file_name_template,
)
@ -103,15 +105,15 @@ class TestWhisperFileTranscriber:
assert srt.endswith(".txt")
srt = get_output_file_path(
task=FileTranscriptionTask(
file_path=file_path,
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(
file_paths=[],
),
model_path="",
file_path=file_path,
language=None,
task=Task.TRANSLATE,
model=TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
output_format=OutputFormat.SRT,
output_directory="",
export_file_name_template=export_file_name_template,
)
assert srt.startswith(expected_starts_with)

View file

@ -6,50 +6,26 @@ import pytest
from PyQt6.QtCore import QSize, Qt
from PyQt6.QtGui import QKeyEvent, QAction
from PyQt6.QtWidgets import (
QTableWidget,
QMessageBox,
QPushButton,
QToolBar,
QMenuBar,
QTableView,
)
from _pytest.fixtures import SubRequest
from pytestqt.qtbot import QtBot
from buzz.cache import TasksCache
from buzz.transcriber.transcriber import (
FileTranscriptionTask,
TranscriptionOptions,
FileTranscriptionOptions,
)
from buzz.db.entity.transcription import Transcription
from buzz.db.service.transcription_service import TranscriptionService
from buzz.widgets.main_window import MainWindow
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
TranscriptionViewerWidget,
)
mock_tasks = [
FileTranscriptionTask(
file_path="",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
status=FileTranscriptionTask.Status.COMPLETED,
),
FileTranscriptionTask(
file_path="",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
status=FileTranscriptionTask.Status.CANCELED,
),
FileTranscriptionTask(
file_path="",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
status=FileTranscriptionTask.Status.FAILED,
error="Error",
),
mock_transcriptions: List[Transcription] = [
Transcription(status="completed"),
Transcription(status="canceled"),
Transcription(status="failed", error_message="Error"),
]
@ -58,37 +34,41 @@ def get_test_asset(filename: str):
class TestMainWindow:
def test_should_set_window_title_and_icon(self, qtbot):
window = MainWindow()
def test_should_set_window_title_and_icon(self, qtbot, transcription_service):
window = MainWindow(transcription_service)
qtbot.add_widget(window)
assert window.windowTitle() == "Buzz"
assert window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
window.close()
def test_should_run_file_transcription_task(self, qtbot: QtBot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
def test_should_run_file_transcription_task(
self, qtbot: QtBot, transcription_service
):
window = MainWindow(transcription_service)
self.import_file_and_start_transcription(window)
self._import_file_and_start_transcription(window)
open_transcript_action = self._get_toolbar_action(window, "Open Transcript")
assert open_transcript_action.isEnabled() is False
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget = self._get_tasks_table(window)
qtbot.wait_until(
self.get_assert_task_status_callback(table_widget, 0, "Completed"),
self._get_assert_task_status_callback(table_widget, 0, "completed"),
timeout=2 * 60 * 1000,
)
table_widget.setCurrentIndex(
table_widget.indexFromItem(table_widget.item(0, 1))
)
table_widget.setCurrentIndex(table_widget.model().index(0, 0))
assert open_transcript_action.isEnabled()
window.close()
@staticmethod
def _get_tasks_table(window: MainWindow) -> QTableView:
return window.findChild(QTableView)
def test_should_run_url_import_file_transcription_task(
self, qtbot: QtBot, tasks_cache
self, qtbot: QtBot, db, transcription_service
):
window = MainWindow(tasks_cache=tasks_cache)
window = MainWindow(transcription_service)
menu: QMenuBar = window.menuBar()
file_action = menu.actions()[0]
import_url_action: QAction = file_action.menu().actions()[1]
@ -105,24 +85,26 @@ class TestMainWindow:
run_button: QPushButton = file_transcriber_widget.findChild(QPushButton)
run_button.click()
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget = self._get_tasks_table(window)
qtbot.wait_until(
self.get_assert_task_status_callback(table_widget, 0, "Completed"),
self._get_assert_task_status_callback(table_widget, 0, "completed"),
timeout=2 * 60 * 1000,
)
window.close()
def test_should_run_and_cancel_transcription_task(self, qtbot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
def test_should_run_and_cancel_transcription_task(
self, qtbot, db, transcription_service
):
window = MainWindow(transcription_service)
qtbot.add_widget(window)
self.import_file_and_start_transcription(window, long_audio=True)
self._import_file_and_start_transcription(window, long_audio=True)
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget = self._get_tasks_table(window)
qtbot.wait_until(
self.get_assert_task_status_callback(table_widget, 0, "In Progress"),
self._get_assert_task_status_callback(table_widget, 0, "in_progress"),
timeout=2 * 60 * 1000,
)
@ -131,7 +113,7 @@ class TestMainWindow:
window.toolbar.stop_transcription_action.trigger()
qtbot.wait_until(
self.get_assert_task_status_callback(table_widget, 0, "Canceled"),
self._get_assert_task_status_callback(table_widget, 0, "canceled"),
timeout=60 * 1000,
)
@ -141,60 +123,72 @@ class TestMainWindow:
window.close()
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_load_tasks_from_cache(self, qtbot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_load_tasks_from_cache(
self, qtbot, transcription_dao, transcription_segment_dao
):
window = MainWindow(
TranscriptionService(transcription_dao, transcription_segment_dao)
)
qtbot.add_widget(window)
table_widget: QTableWidget = window.findChild(QTableWidget)
assert table_widget.rowCount() == 3
table_widget = self._get_tasks_table(window)
assert table_widget.model().rowCount() == 3
assert table_widget.item(0, 4).text() == "Completed"
assert self._get_status(table_widget, 0) == "completed"
table_widget.selectRow(0)
assert window.toolbar.open_transcript_action.isEnabled()
assert table_widget.item(1, 4).text() == "Canceled"
assert self._get_status(table_widget, 1) == "canceled"
table_widget.selectRow(1)
assert window.toolbar.open_transcript_action.isEnabled() is False
assert table_widget.item(2, 4).text() == "Failed (Error)"
assert self._get_status(table_widget, 2) == "failed"
table_widget.selectRow(2)
assert window.toolbar.open_transcript_action.isEnabled() is False
window.close()
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_clear_history_with_rows_selected(self, qtbot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_clear_history_with_rows_selected(
self, qtbot, transcription_dao, transcription_segment_dao
):
window = MainWindow(
TranscriptionService(transcription_dao, transcription_segment_dao)
)
qtbot.add_widget(window)
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget = self._get_tasks_table(window)
table_widget.selectAll()
with patch("PyQt6.QtWidgets.QMessageBox.question") as question_message_box_mock:
question_message_box_mock.return_value = QMessageBox.StandardButton.Yes
window.toolbar.clear_history_action.trigger()
assert table_widget.rowCount() == 0
assert table_widget.model().rowCount() == 0
window.close()
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_have_clear_history_action_disabled_with_no_rows_selected(
self, qtbot, tasks_cache
self, qtbot, transcription_dao, transcription_segment_dao
):
window = MainWindow(tasks_cache=tasks_cache)
window = MainWindow(
TranscriptionService(transcription_dao, transcription_segment_dao)
)
qtbot.add_widget(window)
assert window.toolbar.clear_history_action.isEnabled() is False
window.close()
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_open_transcription_viewer_when_menu_action_is_clicked(
self, qtbot, tasks_cache
self, qtbot, transcription_dao, transcription_segment_dao
):
window = MainWindow(tasks_cache=tasks_cache)
window = MainWindow(
TranscriptionService(transcription_dao, transcription_segment_dao)
)
qtbot.add_widget(window)
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget = self._get_tasks_table(window)
table_widget.selectRow(0)
window.toolbar.open_transcript_action.trigger()
@ -204,14 +198,16 @@ class TestMainWindow:
window.close()
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_open_transcription_viewer_when_return_clicked(
self, qtbot, tasks_cache
self, qtbot, transcription_dao, transcription_segment_dao
):
window = MainWindow(tasks_cache=tasks_cache)
window = MainWindow(
TranscriptionService(transcription_dao, transcription_segment_dao)
)
qtbot.add_widget(window)
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget = self._get_tasks_table(window)
table_widget.selectRow(0)
table_widget.keyPressEvent(
QKeyEvent(
@ -227,18 +223,20 @@ class TestMainWindow:
window.close()
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
@pytest.mark.parametrize("transcription_dao", [mock_transcriptions], indirect=True)
def test_should_have_open_transcript_action_disabled_with_no_rows_selected(
self, qtbot, tasks_cache
self, qtbot, transcription_dao, transcription_segment_dao
):
window = MainWindow(tasks_cache=tasks_cache)
window = MainWindow(
TranscriptionService(transcription_dao, transcription_segment_dao)
)
qtbot.add_widget(window)
assert window.toolbar.open_transcript_action.isEnabled() is False
window.close()
@staticmethod
def import_file_and_start_transcription(
def _import_file_and_start_transcription(
window: MainWindow, long_audio: bool = False
):
with patch(
@ -264,34 +262,24 @@ class TestMainWindow:
run_button.click()
@staticmethod
def get_assert_task_status_callback(
table_widget: QTableWidget,
def _get_assert_task_status_callback(
table_widget: QTableView,
row_index: int,
expected_status: str,
long_audio: bool = False,
):
def assert_task_status():
assert table_widget.rowCount() > 0
assert (
table_widget.item(row_index, 1).text() == "audio-long.mp3"
if long_audio
else "whisper-french.mp3"
assert table_widget.model().rowCount() > 0
assert expected_status in TestMainWindow._get_status(
table_widget, row_index
)
assert expected_status in table_widget.item(row_index, 4).text()
return assert_task_status
@staticmethod
def _get_status(table_widget: QTableView, row_index: int):
return table_widget.model().index(row_index, 9).data()
@staticmethod
def _get_toolbar_action(window: MainWindow, text: str):
toolbar: QToolBar = window.findChild(QToolBar)
return [action for action in toolbar.actions() if action.text() == text][0]
@pytest.fixture()
def tasks_cache(tmp_path, request: SubRequest):
cache = TasksCache(cache_dir=str(tmp_path))
if hasattr(request, "param"):
tasks: List[FileTranscriptionTask] = request.param
cache.save(tasks)
yield cache
cache.clear()

View file

@ -1,115 +1,7 @@
import datetime
from pytestqt.qtbot import QtBot
from buzz.transcriber.transcriber import (
FileTranscriptionTask,
TranscriptionOptions,
FileTranscriptionOptions,
)
from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
class TestTranscriptionTasksTableWidget:
def test_upsert_task(self, qtbot: QtBot):
def test_can_create(self, qtbot, reset_settings):
widget = TranscriptionTasksTableWidget()
qtbot.add_widget(widget)
task = FileTranscriptionTask(
id=0,
file_path="testdata/whisper-french.mp3",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=["testdata/whisper-french.mp3"]
),
model_path="",
status=FileTranscriptionTask.Status.QUEUED,
)
task.queued_at = datetime.datetime(2023, 4, 12, 0, 0, 0)
task.started_at = datetime.datetime(2023, 4, 12, 0, 0, 5)
widget.upsert_task(task)
assert widget.rowCount() == 1
self.assert_row_text(
widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Queued"
)
task.status = FileTranscriptionTask.Status.IN_PROGRESS
task.fraction_completed = 0.3524
widget.upsert_task(task)
assert widget.rowCount() == 1
self.assert_row_text(
widget,
0,
"whisper-french.mp3",
"Whisper (Tiny)",
"Transcribe",
"In Progress (35%)",
)
task.status = FileTranscriptionTask.Status.COMPLETED
task.completed_at = datetime.datetime(2023, 4, 12, 0, 0, 10)
widget.upsert_task(task)
assert widget.rowCount() == 1
self.assert_row_text(
widget,
0,
"whisper-french.mp3",
"Whisper (Tiny)",
"Transcribe",
"Completed (5s)",
)
def test_upsert_task_no_timings(self, qtbot: QtBot):
widget = TranscriptionTasksTableWidget()
qtbot.add_widget(widget)
task = FileTranscriptionTask(
id=0,
file_path="testdata/whisper-french.mp3",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=["testdata/whisper-french.mp3"]
),
model_path="",
status=FileTranscriptionTask.Status.COMPLETED,
)
widget.upsert_task(task)
assert widget.rowCount() == 1
self.assert_row_text(
widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Completed"
)
def test_toggle_column_visibility(self, qtbot, reset_settings):
widget = TranscriptionTasksTableWidget()
qtbot.add_widget(widget)
assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.TASK_ID.value)
assert not widget.isColumnHidden(
TranscriptionTasksTableWidget.Column.FILE_NAME.value
)
assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.MODEL.value)
assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.TASK.value)
assert not widget.isColumnHidden(
TranscriptionTasksTableWidget.Column.STATUS.value
)
# TODO: open context menu and toggle column visibility
def assert_row_text(
self,
widget: TranscriptionTasksTableWidget,
row: int,
filename: str,
model: str,
task: str,
status: str,
):
assert widget.item(row, 1).text() == filename
assert widget.item(row, 2).text() == model
assert widget.item(row, 3).text() == task
assert widget.item(row, 4).text() == status

View file

@ -1,16 +1,15 @@
import pathlib
import uuid
from unittest.mock import patch
import pytest
from PyQt6.QtWidgets import QPushButton, QToolBar
from PyQt6.QtWidgets import QPushButton
from pytestqt.qtbot import QtBot
from buzz.transcriber.transcriber import (
FileTranscriptionTask,
FileTranscriptionOptions,
TranscriptionOptions,
Segment,
)
from buzz.db.entity.transcription import Transcription
from buzz.db.entity.transcription_segment import TranscriptionSegment
from buzz.model_loader import ModelType, WhisperModelSize
from buzz.transcriber.transcriber import Task
from buzz.widgets.transcription_viewer.transcription_segments_editor_widget import (
TranscriptionSegmentsEditorWidget,
)
@ -21,22 +20,29 @@ from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
class TestTranscriptionViewerWidget:
@pytest.fixture()
def task(self) -> FileTranscriptionTask:
return FileTranscriptionTask(
id=0,
file_path="testdata/whisper-french.mp3",
file_transcription_options=FileTranscriptionOptions(
file_paths=["testdata/whisper-french.mp3"]
),
transcription_options=TranscriptionOptions(),
segments=[Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")],
model_path="",
def transcription(
self, transcription_dao, transcription_segment_dao
) -> Transcription:
id = uuid.uuid4()
transcription_dao.insert(
Transcription(
id=str(id),
status="completed",
file="testdata/whisper-french.mp3",
task=Task.TRANSCRIBE.value,
model_type=ModelType.WHISPER.value,
whisper_model_size=WhisperModelSize.SMALL.value,
)
)
transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", str(id)))
transcription_segment_dao.insert(
TranscriptionSegment(299, 329, "venue dans", str(id))
)
def test_should_display_segments(self, qtbot: QtBot, task):
widget = TranscriptionViewerWidget(
transcription_task=task, open_transcription_output=False
)
return transcription_dao.find_by_id(str(id))
def test_should_display_segments(self, qtbot: QtBot, transcription):
widget = TranscriptionViewerWidget(transcription)
qtbot.add_widget(widget)
assert widget.windowTitle() == "whisper-french.mp3"
@ -44,37 +50,23 @@ class TestTranscriptionViewerWidget:
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
assert editor.item(0, 0).text() == "00:00:00.040"
assert editor.item(0, 1).text() == "00:00:00.299"
assert editor.item(0, 2).text() == "Bien"
assert editor.model().index(0, 1).data() == 299
assert editor.model().index(0, 2).data() == 40
assert editor.model().index(0, 3).data() == "Bien"
def test_should_update_segment_text(self, qtbot, task):
widget = TranscriptionViewerWidget(
transcription_task=task, open_transcription_output=False
)
def test_should_update_segment_text(self, qtbot, transcription):
widget = TranscriptionViewerWidget(transcription)
qtbot.add_widget(widget)
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
# Change text
editor.item(0, 2).setText("Biens")
assert task.segments[0].text == "Biens"
editor.model().setData(editor.model().index(0, 3), "Biens")
# Undo
toolbar = widget.findChild(QToolBar)
undo_action, redo_action = toolbar.actions()
undo_action.trigger()
assert task.segments[0].text == "Bien"
redo_action.trigger()
assert task.segments[0].text == "Biens"
def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot, task):
widget = TranscriptionViewerWidget(
transcription_task=task, open_transcription_output=False
)
def test_should_export_segments(
self, tmp_path: pathlib.Path, qtbot: QtBot, transcription
):
widget = TranscriptionViewerWidget(transcription)
qtbot.add_widget(widget)
export_button = widget.findChild(QPushButton)
@ -87,5 +79,5 @@ class TestTranscriptionViewerWidget:
save_file_name_mock.return_value = (str(output_file_path), "")
export_button.menu().actions()[0].trigger()
output_file = open(output_file_path, "r", encoding="utf-8")
assert "Bien\nvenue dans" in output_file.read()
with open(output_file_path, encoding="utf-8") as output_file:
assert "Bien\nvenue dans" in output_file.read()