From 10e74edf89ccfdd94a1368957226cd07ed0132ce Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Thu, 6 Nov 2025 13:51:01 +0200
Subject: [PATCH 01/73] Add test timeout (#1277)
---
README.md | 4 ++--
buzz/db/migrator.py | 3 ++-
buzz/file_transcriber_queue_worker.py | 2 +-
buzz/transcriber/recording_transcriber.py | 4 ++--
buzz/transcriber/whisper_file_transcriber.py | 6 +++---
buzz/translator.py | 3 ++-
buzz/widgets/main_window.py | 2 +-
buzz/widgets/recording_transcriber_widget.py | 4 ++++
.../transcription_viewer_widget.py | 2 +-
pytest.ini | 2 ++
tests/cli_test.py | 2 +-
tests/translator_test.py | 10 +++++-----
tests/widgets/export_transcription_menu_test.py | 2 +-
.../transcription_segments_editor_widget_test.py | 2 +-
.../transcription_viewer_widget_additional_test.py | 2 +-
tests/widgets/transcription_viewer_test.py | 2 +-
16 files changed, 30 insertions(+), 22 deletions(-)
diff --git a/README.md b/README.md
index c53ee4a6..173d25e4 100644
--- a/README.md
+++ b/README.md
@@ -41,11 +41,11 @@ Install with [brew utility](https://brew.sh/)
brew install --cask buzz
```
-Or download the `.dmg` from the [releases page](https://github.com/chidiwilliams/buzz/releases/latest).
+Or download the `.dmg` from [SourceForge](https://sourceforge.net/projects/buzz-captions/files/).
### Windows
-Download and run the `.exe` from the [releases page](https://github.com/chidiwilliams/buzz/releases/latest).
+Get the installation files from [SourceForge](https://sourceforge.net/projects/buzz-captions/files/).
App is not signed, you will get a warning when you install it. Select `More info` -> `Run anyway`.
diff --git a/buzz/db/migrator.py b/buzz/db/migrator.py
index 0fa6b043..d36f9b34 100644
--- a/buzz/db/migrator.py
+++ b/buzz/db/migrator.py
@@ -69,7 +69,8 @@ class DBMigrator:
msg_argv += (args,)
else:
args = []
- logging.info(msg_tmpl, *msg_argv)
+ # Uncomment this to get debugging information
+ # logging.info(msg_tmpl, *msg_argv)
self.db.execute(sql, args)
self.n_changes += 1
diff --git a/buzz/file_transcriber_queue_worker.py b/buzz/file_transcriber_queue_worker.py
index 24fe8013..f6cf91fb 100644
--- a/buzz/file_transcriber_queue_worker.py
+++ b/buzz/file_transcriber_queue_worker.py
@@ -139,7 +139,7 @@ class FileTranscriberQueueWorker(QObject):
self.current_transcriber.stop()
if self.current_transcriber_thread is not None:
- if not self.current_transcriber_thread.wait(3000):
+ if not self.current_transcriber_thread.wait(5000):
logging.warning("Transcriber thread did not terminate gracefully")
self.current_transcriber_thread.terminate()
diff --git a/buzz/transcriber/recording_transcriber.py b/buzz/transcriber/recording_transcriber.py
index 5c71b8ba..8e5cc3d1 100644
--- a/buzz/transcriber/recording_transcriber.py
+++ b/buzz/transcriber/recording_transcriber.py
@@ -326,7 +326,7 @@ class RecordingTranscriber(QObject):
self.is_running = False
if self.process and self.process.poll() is None:
self.process.terminate()
- self.process.wait()
+ self.process.wait(5000)
def start_local_whisper_server(self):
self.transcription.emit(_("Starting Whisper.cpp..."))
@@ -416,4 +416,4 @@ class RecordingTranscriber(QObject):
def __del__(self):
if self.process and self.process.poll() is None:
self.process.terminate()
- self.process.wait()
\ No newline at end of file
+ self.process.wait(5000)
\ No newline at end of file
diff --git a/buzz/transcriber/whisper_file_transcriber.py b/buzz/transcriber/whisper_file_transcriber.py
index 1b2ea99e..c5533397 100644
--- a/buzz/transcriber/whisper_file_transcriber.py
+++ b/buzz/transcriber/whisper_file_transcriber.py
@@ -274,11 +274,11 @@ class WhisperFileTranscriber(FileTranscriber):
if self.started_process:
self.current_process.terminate()
# Use timeout to avoid hanging indefinitely
- self.current_process.join(timeout=5)
+ self.current_process.join(timeout=10)
if self.current_process.is_alive():
logging.warning("Process didn't terminate gracefully, force killing")
self.current_process.kill()
- self.current_process.join(timeout=2)
+ self.current_process.join(timeout=5)
# Close pipes to unblock the read_line thread
try:
@@ -291,7 +291,7 @@ class WhisperFileTranscriber(FileTranscriber):
# Join read_line_thread with timeout to prevent hanging
if self.read_line_thread and self.read_line_thread.is_alive():
- self.read_line_thread.join(timeout=3)
+ self.read_line_thread.join(timeout=5)
if self.read_line_thread.is_alive():
logging.warning("Read line thread didn't terminate gracefully")
diff --git a/buzz/translator.py b/buzz/translator.py
index 0243aacf..56a816ea 100644
--- a/buzz/translator.py
+++ b/buzz/translator.py
@@ -68,7 +68,8 @@ class Translator(QObject):
messages=[
{"role": "system", "content": self.transcription_options.llm_prompt},
{"role": "user", "content": transcript}
- ]
+ ],
+ timeout=30.0
)
except Exception as e:
completion = None
diff --git a/buzz/widgets/main_window.py b/buzz/widgets/main_window.py
index ed471ec6..0ca97cd0 100644
--- a/buzz/widgets/main_window.py
+++ b/buzz/widgets/main_window.py
@@ -425,7 +425,7 @@ class MainWindow(QMainWindow):
self.transcriber_worker.stop()
self.transcriber_thread.quit()
- self.transcriber_thread.wait()
+ self.transcriber_thread.wait(5000) # Wait up to 5 seconds
if self.transcription_viewer_widget is not None:
self.transcription_viewer_widget.close()
diff --git a/buzz/widgets/recording_transcriber_widget.py b/buzz/widgets/recording_transcriber_widget.py
index b336121b..80ae166d 100644
--- a/buzz/widgets/recording_transcriber_widget.py
+++ b/buzz/widgets/recording_transcriber_widget.py
@@ -624,6 +624,10 @@ class RecordingTranscriberWidget(QWidget):
if self.translator is not None:
self.translator.stop()
+ if self.translation_thread is not None:
+ self.translation_thread.quit()
+ self.translation_thread.wait(35_000) # Wait up to 35 seconds
+
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE,
self.transcription_options.language,
diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
index 5b9abeab..ba53226a 100644
--- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
+++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
@@ -1348,7 +1348,7 @@ class TranscriptionViewerWidget(QWidget):
self.translator.stop()
self.translation_thread.quit()
- self.translation_thread.wait()
+ self.translation_thread.wait(35_000) # Wait up to 35 seconds, translation thread also has timeouts, wait longer
super().closeEvent(event)
diff --git a/pytest.ini b/pytest.ini
index ad52348a..b1ef248a 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -5,5 +5,7 @@ qt_api=pyqt6
log_format = %(asctime)s %(levelname)s %(module)s::%(funcName)s %(message)s
log_date_format = %Y-%m-%d %H:%M:%S
addopts = -x
+timeout = 600
+timeout_method = thread
markers =
timeout: set a timeout on a test function.
\ No newline at end of file
diff --git a/tests/cli_test.py b/tests/cli_test.py
index 9bd077d1..7887acf3 100644
--- a/tests/cli_test.py
+++ b/tests/cli_test.py
@@ -20,7 +20,7 @@ class TestCLI:
"--task",
"transcribe",
"--model-size",
- "small",
+ "tiny",
"--output-directory",
mkdtemp(),
"--txt",
diff --git a/tests/translator_test.py b/tests/translator_test.py
index 56db2fc3..6c0f87d6 100644
--- a/tests/translator_test.py
+++ b/tests/translator_test.py
@@ -13,7 +13,7 @@ from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDi
class TestTranslator:
@patch('buzz.translator.OpenAI', autospec=True)
@patch('buzz.translator.queue.Queue', autospec=True)
- def test_start(self, mock_queue, mock_openai):
+ def test_start(self, mock_queue, mock_openai, qtbot):
def side_effect(*args, **kwargs):
side_effect.call_count += 1
@@ -106,11 +106,11 @@ class TestTranslator:
if self.translator is not None:
self.translator.stop()
- self.translator.deleteLater()
if self.translation_thread is not None:
self.translation_thread.quit()
- self.translation_thread.deleteLater()
+ # Wait for the thread to actually finish before cleanup
+ self.translation_thread.wait()
- # Wait to clean-up threads
- time.sleep(3)
+ # Note: translator and translation_thread will be automatically deleted
+ # via the deleteLater() connections set up earlier
diff --git a/tests/widgets/export_transcription_menu_test.py b/tests/widgets/export_transcription_menu_test.py
index 7c15f1c4..30a735be 100644
--- a/tests/widgets/export_transcription_menu_test.py
+++ b/tests/widgets/export_transcription_menu_test.py
@@ -32,7 +32,7 @@ class TestExportTranscriptionMenu:
file=test_audio_path,
task=Task.TRANSCRIBE.value,
model_type=ModelType.WHISPER.value,
- whisper_model_size=WhisperModelSize.SMALL.value,
+ whisper_model_size=WhisperModelSize.TINY.value,
)
)
transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", "", str(id)))
diff --git a/tests/widgets/transcription_viewer/transcription_segments_editor_widget_test.py b/tests/widgets/transcription_viewer/transcription_segments_editor_widget_test.py
index 5e4fab68..ac8036a9 100644
--- a/tests/widgets/transcription_viewer/transcription_segments_editor_widget_test.py
+++ b/tests/widgets/transcription_viewer/transcription_segments_editor_widget_test.py
@@ -289,7 +289,7 @@ class TestTranscriptionSegmentsEditorWidget:
file=test_audio_path,
task=Task.TRANSCRIBE.value,
model_type=ModelType.WHISPER.value,
- whisper_model_size=WhisperModelSize.SMALL.value,
+ whisper_model_size=WhisperModelSize.TINY.value,
)
)
transcription_segment_dao.insert(
diff --git a/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py b/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
index cb1ceb66..8d34460c 100644
--- a/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
+++ b/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
@@ -32,7 +32,7 @@ class TestTranscriptionViewerWidgetAdditional:
file=test_audio_path,
task=Task.TRANSCRIBE.value,
model_type=ModelType.WHISPER.value,
- whisper_model_size=WhisperModelSize.SMALL.value,
+ whisper_model_size=WhisperModelSize.TINY.value,
)
)
transcription_segment_dao.insert(
diff --git a/tests/widgets/transcription_viewer_test.py b/tests/widgets/transcription_viewer_test.py
index ebc5ac01..13d87bc8 100644
--- a/tests/widgets/transcription_viewer_test.py
+++ b/tests/widgets/transcription_viewer_test.py
@@ -42,7 +42,7 @@ class TestTranscriptionViewerWidget:
file=test_audio_path,
task=Task.TRANSCRIBE.value,
model_type=ModelType.WHISPER.value,
- whisper_model_size=WhisperModelSize.SMALL.value,
+ whisper_model_size=WhisperModelSize.TINY.value,
)
)
transcription_segment_dao.insert(TranscriptionSegment(40, 299, "Bien", "", str(id)))
From 79d8aadf2f38dc865ba8502491521d0c2f6af004 Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Sat, 8 Nov 2025 21:21:19 +0200
Subject: [PATCH 02/73] Inline demucs (#1279)
---
.github/workflows/ci.yml | 8 +
Buzz.spec | 1 -
Makefile | 3 +-
buzz/__version__.py | 2 +-
buzz/widgets/main_window.py | 5 +-
buzz/widgets/recording_transcriber_widget.py | 5 +-
.../transcription_viewer_widget.py | 6 +-
demucs/Readme.md | 1 +
demucs/__init__.py | 7 +
demucs/__main__.py | 10 +
demucs/api.py | 393 ++++++++
demucs/apply.py | 322 +++++++
demucs/audio.py | 266 ++++++
demucs/audio_legacy.py | 17 +
demucs/augment.py | 111 +++
demucs/demucs.py | 447 ++++++++++
demucs/distrib.py | 100 +++
demucs/ema.py | 66 ++
demucs/evaluate.py | 174 ++++
demucs/grids/__init__.py | 0
demucs/grids/_explorers.py | 64 ++
demucs/grids/mdx.py | 33 +
demucs/grids/mdx_extra.py | 36 +
demucs/grids/mdx_refine.py | 34 +
demucs/grids/mmi.py | 69 ++
demucs/grids/mmi_ft.py | 55 ++
demucs/grids/repro.py | 50 ++
demucs/grids/repro_ft.py | 46 +
demucs/grids/sdx23.py | 19 +
demucs/hdemucs.py | 796 +++++++++++++++++
demucs/htdemucs.py | 661 ++++++++++++++
demucs/pretrained.py | 98 ++
demucs/py.typed | 0
demucs/repitch.py | 87 ++
demucs/repo.py | 166 ++++
demucs/separate.py | 228 +++++
demucs/solver.py | 405 +++++++++
demucs/spec.py | 47 +
demucs/states.py | 163 ++++
demucs/svd.py | 83 ++
demucs/train.py | 252 ++++++
demucs/transformer.py | 839 ++++++++++++++++++
demucs/utils.py | 149 ++++
demucs/wav.py | 255 ++++++
demucs/wdemucs.py | 9 +
hatch_build.py | 54 ++
pyproject.toml | 5 +-
pytest.ini | 2 +-
.../io.github.chidiwilliams.Buzz.metainfo.xml | 6 +-
snap/snapcraft.yaml | 2 +-
...scription_viewer_widget_additional_test.py | 35 +-
uv.lock | 60 +-
52 files changed, 6662 insertions(+), 90 deletions(-)
create mode 100644 demucs/Readme.md
create mode 100644 demucs/__init__.py
create mode 100644 demucs/__main__.py
create mode 100644 demucs/api.py
create mode 100644 demucs/apply.py
create mode 100644 demucs/audio.py
create mode 100644 demucs/audio_legacy.py
create mode 100644 demucs/augment.py
create mode 100644 demucs/demucs.py
create mode 100644 demucs/distrib.py
create mode 100644 demucs/ema.py
create mode 100755 demucs/evaluate.py
create mode 100644 demucs/grids/__init__.py
create mode 100644 demucs/grids/_explorers.py
create mode 100644 demucs/grids/mdx.py
create mode 100644 demucs/grids/mdx_extra.py
create mode 100644 demucs/grids/mdx_refine.py
create mode 100644 demucs/grids/mmi.py
create mode 100644 demucs/grids/mmi_ft.py
create mode 100644 demucs/grids/repro.py
create mode 100644 demucs/grids/repro_ft.py
create mode 100644 demucs/grids/sdx23.py
create mode 100644 demucs/hdemucs.py
create mode 100644 demucs/htdemucs.py
create mode 100644 demucs/pretrained.py
create mode 100644 demucs/py.typed
create mode 100644 demucs/repitch.py
create mode 100644 demucs/repo.py
create mode 100644 demucs/separate.py
create mode 100644 demucs/solver.py
create mode 100644 demucs/spec.py
create mode 100644 demucs/states.py
create mode 100644 demucs/svd.py
create mode 100644 demucs/train.py
create mode 100644 demucs/transformer.py
create mode 100755 demucs/utils.py
create mode 100644 demucs/wav.py
create mode 100644 demucs/wdemucs.py
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 56c80943..dbfa02f0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -358,6 +358,7 @@ jobs:
files: |
Buzz*-unix.tar.gz
Buzz*-windows.exe
+ Buzz*-windows-*.bin
Buzz*-mac.dmg
deploy_brew_cask:
@@ -371,6 +372,13 @@ jobs:
with:
submodules: recursive
+ # Should be removed with next update to whisper.cpp
+ - name: Downgrade Xcode
+ uses: maxim-lobanov/setup-xcode@v1
+ with:
+ xcode-version: '16.0.0'
+ if: matrix.os == 'macos-latest'
+
- name: Install uv
uses: astral-sh/setup-uv@v6
diff --git a/Buzz.spec b/Buzz.spec
index 2c6fb968..0f4e8edb 100644
--- a/Buzz.spec
+++ b/Buzz.spec
@@ -13,7 +13,6 @@ datas += collect_data_files("torch")
datas += collect_data_files("demucs")
datas += copy_metadata("tqdm")
datas += copy_metadata("torch")
-datas += copy_metadata("demucs")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
diff --git a/Makefile b/Makefile
index 859d4b88..9b4050ef 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,4 @@
-version := 1.3.2
-version_escaped := $$(echo ${version} | sed -e 's/\./\\./g')
+version := 1.3.3
mac_app_path := ./dist/Buzz.app
mac_zip_path := ./dist/Buzz-${version}-mac.zip
diff --git a/buzz/__version__.py b/buzz/__version__.py
index 3b734b24..e371c8ac 100644
--- a/buzz/__version__.py
+++ b/buzz/__version__.py
@@ -1 +1 @@
-VERSION = "1.3.2"
+VERSION = "1.3.3"
diff --git a/buzz/widgets/main_window.py b/buzz/widgets/main_window.py
index 0ca97cd0..8c605f94 100644
--- a/buzz/widgets/main_window.py
+++ b/buzz/widgets/main_window.py
@@ -425,7 +425,10 @@ class MainWindow(QMainWindow):
self.transcriber_worker.stop()
self.transcriber_thread.quit()
- self.transcriber_thread.wait(5000) # Wait up to 5 seconds
+ # Only wait if thread is actually running
+ if self.transcriber_thread.isRunning():
+ if not self.transcriber_thread.wait(5000): # Wait up to 5 seconds
+ logging.warning("Transcriber thread did not finish within timeout")
if self.transcription_viewer_widget is not None:
self.transcription_viewer_widget.close()
diff --git a/buzz/widgets/recording_transcriber_widget.py b/buzz/widgets/recording_transcriber_widget.py
index 80ae166d..b036fa03 100644
--- a/buzz/widgets/recording_transcriber_widget.py
+++ b/buzz/widgets/recording_transcriber_widget.py
@@ -626,7 +626,10 @@ class RecordingTranscriberWidget(QWidget):
if self.translation_thread is not None:
self.translation_thread.quit()
- self.translation_thread.wait(35_000) # Wait up to 35 seconds
+ # Only wait if thread is actually running
+ if self.translation_thread.isRunning():
+ if not self.translation_thread.wait(45_000):
+ logging.warning("Translation thread did not finish within timeout")
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE,
diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
index ba53226a..bf4400b3 100644
--- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
+++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
@@ -1348,7 +1348,11 @@ class TranscriptionViewerWidget(QWidget):
self.translator.stop()
self.translation_thread.quit()
- self.translation_thread.wait(35_000) # Wait up to 35 seconds, translation thread also has timeouts, wait longer
+
+ # Only wait if thread is actually running
+ if self.translation_thread.isRunning():
+ if not self.translation_thread.wait(45_000):
+ logging.warning("Translation thread did not finish within timeout")
super().closeEvent(event)
diff --git a/demucs/Readme.md b/demucs/Readme.md
new file mode 100644
index 00000000..402d2b4a
--- /dev/null
+++ b/demucs/Readme.md
@@ -0,0 +1 @@
+Inlined demucs https://github.com/adefossez/demucs
\ No newline at end of file
diff --git a/demucs/__init__.py b/demucs/__init__.py
new file mode 100644
index 00000000..3bf9f708
--- /dev/null
+++ b/demucs/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+__version__ = "4.1.0a3"
diff --git a/demucs/__main__.py b/demucs/__main__.py
new file mode 100644
index 00000000..da0a5410
--- /dev/null
+++ b/demucs/__main__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .separate import main
+
+if __name__ == '__main__':
+ main()
diff --git a/demucs/api.py b/demucs/api.py
new file mode 100644
index 00000000..ee8a5126
--- /dev/null
+++ b/demucs/api.py
@@ -0,0 +1,393 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""API methods for demucs
+
+Classes
+-------
+`demucs.api.Separator`: The base separator class
+
+Functions
+---------
+`demucs.api.save_audio`: Save an audio
+`demucs.api.list_models`: Get models list
+
+Examples
+--------
+See the end of this module (if __name__ == "__main__")
+"""
+
+import subprocess
+
+from . import audio_legacy
+import torch as th
+import torchaudio as ta
+
+from dora.log import fatal
+from pathlib import Path
+from typing import Optional, Callable, Dict, Tuple, Union
+
+from .apply import apply_model, _replace_dict
+from .audio import AudioFile, convert_audio, save_audio
+from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
+from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo
+
+
+class LoadAudioError(Exception):
+ pass
+
+
+class LoadModelError(Exception):
+ pass
+
+
+class _NotProvided:
+ pass
+
+
+NotProvided = _NotProvided()
+
+
+class Separator:
+ def __init__(
+ self,
+ model: str = "htdemucs",
+ repo: Optional[Path] = None,
+ device: str = "cuda" if th.cuda.is_available() else "cpu",
+ shifts: int = 1,
+ overlap: float = 0.25,
+ split: bool = True,
+ segment: Optional[int] = None,
+ jobs: int = 0,
+ progress: bool = False,
+ callback: Optional[Callable[[dict], None]] = None,
+ callback_arg: Optional[dict] = None,
+ ):
+ """
+ `class Separator`
+ =================
+
+ Parameters
+ ----------
+ model: Pretrained model name or signature. Default is htdemucs.
+ repo: Folder containing all pre-trained models for use.
+ segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
+ not specified, will use the command line option.
+ shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
+            apply the opposite shift to the output. This is repeated `shifts` times and all \
+ predictions are averaged. This effectively makes the model time equivariant and \
+ improves SDR by up to 0.2 points. If not specified, will use the command line option.
+ split: If True, the input will be broken down into small chunks (length set by `segment`) \
+ and predictions will be performed individually on each and concatenated. Useful for \
+ model with large memory footprint like Tasnet. If not specified, will use the command \
+ line option.
+ overlap: The overlap between the splits. If not specified, will use the command line \
+ option.
+ device (torch.device, str, or None): If provided, device on which to execute the \
+ computation, otherwise `wav.device` is assumed. When `device` is different from \
+ `wav.device`, only local computations will be on `device`, while the entire tracks \
+ will be stored on `wav.device`. If not specified, will use the command line option.
+ jobs: Number of jobs. This can increase memory usage but will be much faster when \
+ multiple cores are available. If not specified, will use the command line option.
+ callback: A function will be called when the separation of a chunk starts or finished. \
+ The argument passed to the function will be a dict. For more information, please see \
+ the Callback section.
+ callback_arg: A dict containing private parameters to be passed to callback function. For \
+ more information, please see the Callback section.
+ progress: If true, show a progress bar.
+
+ Callback
+ --------
+ The function will be called with only one positional parameter whose type is `dict`. The
+ `callback_arg` will be combined with information of current separation progress. The
+ progress information will override the values in `callback_arg` if same key has been used.
+ To abort the separation, raise `KeyboardInterrupt`.
+
+ Progress information contains several keys (These keys will always exist):
+ - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+ - `shift_idx`: The index of shifts. Starts from 0.
+ - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't
+ mean that it is at the 441000 second of the audio, but the "frame" of the tensor.
+ - `state`: Could be `"start"` or `"end"`.
+ - `audio_length`: Length of the audio (in "frame" of the tensor).
+ - `models`: Count of submodels in the model.
+ """
+ self._name = model
+ self._repo = repo
+ self._load_model()
+ self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
+ segment=segment, jobs=jobs, progress=progress, callback=callback,
+ callback_arg=callback_arg)
+
+ def update_parameter(
+ self,
+ device: Union[str, _NotProvided] = NotProvided,
+ shifts: Union[int, _NotProvided] = NotProvided,
+ overlap: Union[float, _NotProvided] = NotProvided,
+ split: Union[bool, _NotProvided] = NotProvided,
+ segment: Optional[Union[int, _NotProvided]] = NotProvided,
+ jobs: Union[int, _NotProvided] = NotProvided,
+ progress: Union[bool, _NotProvided] = NotProvided,
+ callback: Optional[
+ Union[Callable[[dict], None], _NotProvided]
+ ] = NotProvided,
+ callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
+ ):
+ """
+ Update the parameters of separation.
+
+ Parameters
+ ----------
+ segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
+ not specified, will use the command line option.
+ shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
+            apply the opposite shift to the output. This is repeated `shifts` times and all \
+ predictions are averaged. This effectively makes the model time equivariant and \
+ improves SDR by up to 0.2 points. If not specified, will use the command line option.
+ split: If True, the input will be broken down into small chunks (length set by `segment`) \
+ and predictions will be performed individually on each and concatenated. Useful for \
+ model with large memory footprint like Tasnet. If not specified, will use the command \
+ line option.
+ overlap: The overlap between the splits. If not specified, will use the command line \
+ option.
+ device (torch.device, str, or None): If provided, device on which to execute the \
+ computation, otherwise `wav.device` is assumed. When `device` is different from \
+ `wav.device`, only local computations will be on `device`, while the entire tracks \
+ will be stored on `wav.device`. If not specified, will use the command line option.
+ jobs: Number of jobs. This can increase memory usage but will be much faster when \
+ multiple cores are available. If not specified, will use the command line option.
+ callback: A function will be called when the separation of a chunk starts or finished. \
+ The argument passed to the function will be a dict. For more information, please see \
+ the Callback section.
+ callback_arg: A dict containing private parameters to be passed to callback function. For \
+ more information, please see the Callback section.
+ progress: If true, show a progress bar.
+
+ Callback
+ --------
+ The function will be called with only one positional parameter whose type is `dict`. The
+ `callback_arg` will be combined with information of current separation progress. The
+ progress information will override the values in `callback_arg` if same key has been used.
+ To abort the separation, raise `KeyboardInterrupt`.
+
+ Progress information contains several keys (These keys will always exist):
+ - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+ - `shift_idx`: The index of shifts. Starts from 0.
+ - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't
+ mean that it is at the 441000 second of the audio, but the "frame" of the tensor.
+ - `state`: Could be `"start"` or `"end"`.
+ - `audio_length`: Length of the audio (in "frame" of the tensor).
+ - `models`: Count of submodels in the model.
+ """
+ if not isinstance(device, _NotProvided):
+ self._device = device
+ if not isinstance(shifts, _NotProvided):
+ self._shifts = shifts
+ if not isinstance(overlap, _NotProvided):
+ self._overlap = overlap
+ if not isinstance(split, _NotProvided):
+ self._split = split
+ if not isinstance(segment, _NotProvided):
+ self._segment = segment
+ if not isinstance(jobs, _NotProvided):
+ self._jobs = jobs
+ if not isinstance(progress, _NotProvided):
+ self._progress = progress
+ if not isinstance(callback, _NotProvided):
+ self._callback = callback
+ if not isinstance(callback_arg, _NotProvided):
+ self._callback_arg = callback_arg
+
+ def _load_model(self):
+ self._model = get_model(name=self._name, repo=self._repo)
+ if self._model is None:
+ raise LoadModelError("Failed to load model")
+ self._audio_channels = self._model.audio_channels
+ self._samplerate = self._model.samplerate
+
+ def _load_audio(self, track: Path):
+ errors = {}
+ wav = None
+
+ try:
+ wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
+ channels=self._audio_channels)
+ except FileNotFoundError:
+ errors["ffmpeg"] = "FFmpeg is not installed."
+ except subprocess.CalledProcessError:
+ errors["ffmpeg"] = "FFmpeg could not read the file."
+
+ if wav is None:
+ try:
+ wav, sr = ta.load(str(track))
+ except RuntimeError as err:
+ errors["torchaudio"] = err.args[0]
+ else:
+ wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
+
+ if wav is None:
+ raise LoadAudioError(
+ "\n".join(
+ "When trying to load using {}, got the following error: {}".format(
+ backend, error
+ )
+ for backend, error in errors.items()
+ )
+ )
+ return wav
+
+ def separate_tensor(
+ self, wav: th.Tensor, sr: Optional[int] = None
+ ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
+ """
+ Separate a loaded tensor.
+
+ Parameters
+ ----------
+ wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \
+ while the second is the waveform of each channel. Type should be float32. \
+ e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
+ sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \
+ model.
+
+ Returns
+ -------
+ A tuple, whose first element is the original wave and second element is a dict, whose keys
+ are the name of stems and values are separated waves. The original wave will have already
+ been resampled.
+
+ Notes
+ -----
+ Use this function with cautiousness. This function does not provide data verifying.
+ """
+ if sr is not None and sr != self.samplerate:
+ wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
+ ref = wav.mean(0)
+ wav -= ref.mean()
+ wav /= ref.std() + 1e-8
+ out = apply_model(
+ self._model,
+ wav[None],
+ segment=self._segment,
+ shifts=self._shifts,
+ split=self._split,
+ overlap=self._overlap,
+ device=self._device,
+ num_workers=self._jobs,
+ callback=self._callback,
+ callback_arg=_replace_dict(
+ self._callback_arg, ("audio_length", wav.shape[1])
+ ),
+ progress=self._progress,
+ )
+ if out is None:
+ raise KeyboardInterrupt
+ out *= ref.std() + 1e-8
+ out += ref.mean()
+ wav *= ref.std() + 1e-8
+ wav += ref.mean()
+ return (wav, dict(zip(self._model.sources, out[0])))
+
+ def separate_audio_file(self, file: Path):
+ """
+ Separate an audio file. The method will automatically read the file.
+
+ Parameters
+ ----------
+ wav: Path of the file to be separated.
+
+ Returns
+ -------
+ A tuple, whose first element is the original wave and second element is a dict, whose keys
+ are the name of stems and values are separated waves. The original wave will have already
+ been resampled.
+ """
+ return self.separate_tensor(self._load_audio(file), self.samplerate)
+
+ @property
+ def samplerate(self):
+ return self._samplerate
+
+ @property
+ def audio_channels(self):
+ return self._audio_channels
+
+ @property
+ def model(self):
+ return self._model
+
+
+def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
+ """
+ List the available models. Please remember that not all the returned models can be
+ successfully loaded.
+
+ Parameters
+ ----------
+ repo: The repo whose models are to be listed.
+
+ Returns
+ -------
+ A dict with two keys ("single" for single models and "bag" for bag of models). The values are
+ lists whose components are strs.
+ """
+ model_repo: ModelOnlyRepo
+ if repo is None:
+ models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
+ model_repo = RemoteRepo(models)
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
+ else:
+ if not repo.is_dir():
+ fatal(f"{repo} must exist and be a directory.")
+ model_repo = LocalRepo(repo)
+ bag_repo = BagOnlyRepo(repo, model_repo)
+ return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}
+
+
+if __name__ == "__main__":
+ # Test API functions
+ # two-stem not supported
+
+ from .separate import get_parser
+
+ args = get_parser().parse_args()
+ separator = Separator(
+ model=args.name,
+ repo=args.repo,
+ device=args.device,
+ shifts=args.shifts,
+ overlap=args.overlap,
+ split=args.split,
+ segment=args.segment,
+ jobs=args.jobs,
+ callback=print
+ )
+ out = args.out / args.name
+ out.mkdir(parents=True, exist_ok=True)
+ for file in args.tracks:
+ separated = separator.separate_audio_file(file)[1]
+ if args.mp3:
+ ext = "mp3"
+ elif args.flac:
+ ext = "flac"
+ else:
+ ext = "wav"
+ kwargs = {
+ "samplerate": separator.samplerate,
+ "bitrate": args.mp3_bitrate,
+ "clip": args.clip_mode,
+ "as_float": args.float32,
+ "bits_per_sample": 24 if args.int24 else 16,
+ }
+ for stem, source in separated.items():
+ stem = out / args.filename.format(
+ track=Path(file).name.rsplit(".", 1)[0],
+ trackext=Path(file).name.rsplit(".", 1)[-1],
+ stem=stem,
+ ext=ext,
+ )
+ stem.parent.mkdir(parents=True, exist_ok=True)
+ save_audio(source, str(stem), **kwargs)
diff --git a/demucs/apply.py b/demucs/apply.py
new file mode 100644
index 00000000..c84993de
--- /dev/null
+++ b/demucs/apply.py
@@ -0,0 +1,322 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Code to apply a model to a mix. It will handle chunking with overlaps and
+interpolation between chunks, as well as the "shift trick".
+"""
+from concurrent.futures import ThreadPoolExecutor
+import copy
+import random
+from threading import Lock
+import typing as tp
+
+import torch as th
+from torch import nn
+from torch.nn import functional as F
+import tqdm
+
+from .demucs import Demucs
+from .hdemucs import HDemucs
+from .htdemucs import HTDemucs
+from .utils import center_trim, DummyPoolExecutor
+
+Model = tp.Union[Demucs, HDemucs, HTDemucs]
+
+
class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N list (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            # All models must agree on output layout so predictions can be mixed.
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                # For HTDemucs, never enlarge the segment beyond the trained one.
                if not isinstance(other, HTDemucs) or segment <= other.segment:
                    other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            # Default: uniform weight for every (model, source) pair.
            weights = [[1.] * len(first.sources) for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    @property
    def max_allowed_segment(self) -> float:
        # Smallest HTDemucs segment in the bag; inf when no HTDemucs is present.
        limit = float('inf')
        for model in self.models:
            if isinstance(model, HTDemucs):
                limit = min(limit, float(model.segment))
        return limit

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")
+
+
class TensorChunk:
    """A lightweight, copy-free view over a sub-range of a tensor's last axis."""

    def __init__(self, tensor, offset=0, length=None):
        avail = tensor.shape[-1]
        assert offset >= 0
        assert offset < avail

        # Clip the requested length to what the underlying tensor can provide.
        if length is None:
            length = avail - offset
        else:
            length = min(avail - offset, length)

        if isinstance(tensor, TensorChunk):
            # Flatten nested chunks: always reference the root tensor directly.
            self.tensor = tensor.tensor
            self.offset = tensor.offset + offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        # Same shape as the underlying tensor, with the last dim replaced.
        dims = list(self.tensor.shape)
        dims[-1] = self.length
        return dims

    def padded(self, target_length):
        """Return `target_length` samples centered on this chunk, zero-padded
        where the window extends past the underlying tensor."""
        delta = target_length - self.length
        total = self.tensor.shape[-1]
        assert delta >= 0

        begin = self.offset - delta // 2
        end = begin + target_length

        clipped_begin = max(0, begin)
        clipped_end = min(total, end)

        # Whatever falls outside the real tensor is filled with zeros.
        left_fill = clipped_begin - begin
        right_fill = end - clipped_end

        out = F.pad(self.tensor[..., clipped_begin:clipped_end], (left_fill, right_fill))
        assert out.shape[-1] == target_length
        return out
+
+
def tensor_chunk(tensor_or_chunk):
    """Wrap a plain tensor into a TensorChunk; pass existing chunks through."""
    if isinstance(tensor_or_chunk, th.Tensor):
        return TensorChunk(tensor_or_chunk)
    assert isinstance(tensor_or_chunk, TensorChunk)
    return tensor_or_chunk
+
+
+def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict:
+ if _dict is None:
+ _dict = {}
+ else:
+ _dict = copy.copy(_dict)
+ for key, value in subs:
+ _dict[key] = value
+ return _dict
+
+
def apply_model(model: tp.Union[BagOfModels, Model],
                mix: tp.Union[th.Tensor, TensorChunk],
                shifts: int = 1, split: bool = True,
                overlap: float = 0.25, transition_power: float = 1.,
                progress: bool = False, device=None,
                num_workers: int = 0, segment: tp.Optional[float] = None,
                pool=None, lock=None,
                callback: tp.Optional[tp.Callable[[dict], None]] = None,
                callback_arg: tp.Optional[dict] = None) -> th.Tensor:
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of overlap between consecutive chunks when `split` is True.
        transition_power (float): exponent applied to the cross-fade weights between chunks;
            must be >= 1, larger values give sharper transitions.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        num_workers (int): if non zero, device is 'cpu', how many threads to
            use in parallel.
        segment (float or None): override the model segment parameter.
        pool: executor used to process chunks; defaults to a thread pool on CPU
            or a serial dummy executor otherwise.
        lock: lock serializing `callback` invocations; created when not given.
        callback: optional function invoked with progress-state dicts.
        callback_arg: base dict merged into every `callback` payload.
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        # Threads only pay off for CPU inference; otherwise run chunks serially.
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()
    if lock is None:
        lock = Lock()
    # Seed the callback payload with the default progress counters.
    callback_arg = _replace_dict(
        callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items()
    )
    # Keyword arguments forwarded to the recursive calls below.
    kwargs: tp.Dict[str, tp.Any] = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'segment': segment,
        'lock': lock,
    }
    out: tp.Union[float, th.Tensor]
    res: tp.Union[float, th.Tensor]
    if isinstance(model, BagOfModels):
        # Special treatment for bag of model.
        # We explicitly apply multiple times `apply_model` so that the random shifts
        # are different for each model.
        estimates: tp.Union[float, th.Tensor] = 0.
        totals = [0.] * len(model.sources)
        callback_arg["models"] = len(model.models)
        for sub_model, model_weights in zip(model.models, model.weights):
            # Bind the current bag index as a lambda default so the closure does
            # not capture the mutable counter by reference.
            kwargs["callback"] = ((
                lambda d, i=callback_arg["model_idx_in_bag"]: callback(
                    _replace_dict(d, ("model_idx_in_bag", i))) if callback else None)
            )
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)

            res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg)
            out = res
            # Restore the sub-model's original placement after inference.
            sub_model.to(original_model_device)
            # Weight each source estimate and track per-source totals for
            # normalization.
            for k, inst_weight in enumerate(model_weights):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out
            callback_arg["model_idx_in_bag"] += 1

        assert isinstance(estimates, th.Tensor)
        # Normalize by the accumulated per-source weights.
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    if "models" not in callback_arg:
        callback_arg["models"] = 1
    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if shifts:
        # Shift trick: average predictions over random time offsets of up to
        # half a second, undoing the shift on each output before averaging.
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0.
        for shift_idx in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            kwargs["callback"] = (
                (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))
                if callback else None)
            )
            res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg)
            shifted_out = res
            # Drop the leading samples introduced by the shift.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        assert isinstance(out, th.Tensor)
        return out
    elif split:
        # Chunked inference: process overlapping windows and cross-fade them.
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        if segment is None:
            segment = model.segment
        assert segment is not None and segment > 0.
        segment_length: int = int(model.samplerate * segment)
        stride = int((1 - overlap) * segment_length)
        offsets = range(0, length, stride)
        # Progress-bar scale: seconds of audio advanced per chunk.
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
                         th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
        assert len(weight) == segment_length
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment_length)
            future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg,
                                 callback=(lambda d, i=offset:
                                           callback(_replace_dict(d, ("segment_offset", i)))
                                           if callback else None))
            futures.append((future, offset))
            # NOTE(review): this increment has no effect — `offset` is rebound
            # by the `for` loop on the next iteration.
            offset += segment_length
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            try:
                chunk_out = future.result()  # type: th.Tensor
            except Exception:
                # Cancel outstanding chunks before propagating the failure.
                pool.shutdown(wait=True, cancel_futures=True)
                raise
            chunk_length = chunk_out.shape[-1]
            # Accumulate the weighted chunk and the weights used, so the final
            # division re-normalizes overlapping regions.
            out[..., offset:offset + segment_length] += (
                weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        assert isinstance(out, th.Tensor)
        return out
    else:
        # Base case: run the model once on a (possibly padded) window.
        valid_length: int
        if isinstance(model, HTDemucs) and segment is not None:
            valid_length = int(segment * model.samplerate)
        elif hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)  # type: ignore
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(valid_length).to(device)
        # Callbacks are serialized under `lock` since chunks may run in threads.
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "start")))  # type: ignore
        with th.no_grad():
            out = model(padded_mix)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "end")))  # type: ignore
        assert isinstance(out, th.Tensor)
        return center_trim(out, length)
diff --git a/demucs/audio.py b/demucs/audio.py
new file mode 100644
index 00000000..600bd55b
--- /dev/null
+++ b/demucs/audio.py
@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+import subprocess as sp
+from pathlib import Path
+
+import lameenc
+import julius
+import numpy as np
+from . import audio_legacy
+import torch
+import torchaudio as ta
+import typing as tp
+
+from .utils import temp_filenames
+
+
def _read_info(path):
    """Probe `path` with ffprobe and return its metadata as a parsed JSON dict."""
    raw = sp.check_output([
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams'
    ])
    return json.loads(raw.decode('utf-8'))
+
+
class AudioFile:
    """
    Allows to read audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :method:`read` for more details.
    """
    def __init__(self, path: Path):
        self.path = Path(path)
        # Cached ffprobe metadata, filled lazily by the `info` property.
        self._info = None

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        # Probe the file once with ffprobe and cache the result.
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    @property
    def duration(self):
        # Duration in seconds, as reported by the container format metadata.
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        # Indices of the audio streams within the container
        # (video/subtitle streams are skipped).
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        # Number of audio streams in the file.
        return len(self._audio_streams)

    def channels(self, stream=0):
        # Channel count of the given audio stream.
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        # Sample rate (Hz) of the given audio stream.
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float): seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. Original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
        """
        # Normalize `streams` to a list of indices; remember if a single int was
        # given so the stream dimension can be squeezed at the end.
        streams = np.array(range(len(self)))[streams]
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            # Ask ffmpeg for one extra sample, then trim to `target_size`,
            # to avoid rounding shortfalls.
            target_size = int((samplerate or self.samplerate()) * duration)
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            # One raw output file per requested stream, extracted in a single pass.
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                # Raw 32-bit float little-endian samples.
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for filename in filenames:
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                # Interleaved samples -> [channels, time].
                wav = wav.view(-1, self.channels()).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
            wav = torch.stack(wavs, dim=0)
            if single:
                wav = wav[0]
            return wav
+
+
def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels.

    Handles four cases: already matching, downmix to mono (average),
    mono replicated to many channels, and truncation to the first channels.
    Raises ValueError for a multi-channel input narrower than requested.
    """
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        return wav
    if channels == 1:
        # Downmix to mono by averaging over the channel axis.
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Replicate the single channel over all requested channels.
        return wav.expand(*shape, channels, length)
    if src_channels >= channels:
        # More channels than requested: keep the first ones.
        return wav[..., :channels, :]
    # Fewer (but > 1) channels than requested: no reasonable policy.
    raise ValueError('The audio file has less channels than requested but is not mono.')
+
+
def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor:
    """Convert audio from a given samplerate to a target one and target number of channels."""
    # Channel conversion first, then sample-rate conversion via julius.
    channel_converted = convert_audio_channels(wav, channels)
    return julius.resample_frac(channel_converted, from_samplerate, to_samplerate)
+
+
def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format.

    Integer input is returned unchanged. Note the clamp is in place on
    floating-point input (same as the original behavior).
    """
    if not wav.dtype.is_floating_point:
        return wav
    return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
+
+
def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format.

    Floating-point input passes through; integer input is rescaled to [-1, 1].
    """
    if not wav.dtype.is_floating_point:
        wav = wav.float() / (2**15 - 1)
    return wav
+
+
def as_dtype_pcm(wav, dtype):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.

    NOTE(review): `dtype` is currently unused — the branch below dispatches on
    `wav.dtype` instead, so this only normalizes `wav` within its own family
    (float input -> f32 PCM, integer input -> i16 PCM). Confirm whether
    dispatching on `dtype` was the intent.
    """
    if wav.dtype.is_floating_point:
        return f32_pcm(wav)
    else:
        return i16_pcm(wav)
+
+
def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False):
    """Save given audio as mp3. This should work on all OSes.

    `wav` is expected as [channels, time]; it is converted to int16 PCM and
    fed to the lame encoder as interleaved frames.
    """
    num_channels = wav.shape[0]
    wav = i16_pcm(wav)
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(num_channels)
    encoder.set_quality(quality)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()
    # lame expects interleaved samples: [time, channels].
    pcm = wav.data.cpu().transpose(0, 1).numpy()
    mp3_data = encoder.encode(pcm.tobytes())
    mp3_data += encoder.flush()
    with open(path, "wb") as f:
        f.write(mp3_data)
+
+
+def prevent_clip(wav, mode='rescale'):
+ """
+ different strategies for avoiding raw clipping.
+ """
+ if mode is None or mode == 'none':
+ return wav
+ assert wav.dtype.is_floating_point, "too late for clipping"
+ if mode == 'rescale':
+ wav = wav / max(1.01 * wav.abs().max(), 1)
+ elif mode == 'clamp':
+ wav = wav.clamp(-0.99, 0.99)
+ elif mode == 'tanh':
+ wav = torch.tanh(wav)
+ else:
+ raise ValueError(f"Invalid mode {mode}")
+ return wav
+
+
def save_audio(wav: torch.Tensor,
               path: tp.Union[str, Path],
               samplerate: int,
               bitrate: int = 320,
               clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
               bits_per_sample: tp.Literal[16, 24, 32] = 16,
               as_float: bool = False,
               preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality:
    2 for highest quality, 7 for fastest speed
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True)
        return
    if suffix == ".wav":
        # Float output always uses 32 bits per sample.
        encoding = 'PCM_F' if as_float else 'PCM_S'
        if as_float:
            bits_per_sample = 32
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
    elif suffix == ".flac":
        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
    else:
        raise ValueError(f"Invalid suffix for path: {suffix}")
diff --git a/demucs/audio_legacy.py b/demucs/audio_legacy.py
new file mode 100644
index 00000000..ab6bdce4
--- /dev/null
+++ b/demucs/audio_legacy.py
@@ -0,0 +1,17 @@
# This file is to extend support for torchaudio 2.1

import importlib
import os
import sys
import warnings

# If torchaudio has not been imported yet, selecting the legacy backend
# dispatcher via the environment is enough: the flag is read on first import.
if "torchaudio" not in sys.modules:
    os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0"
elif os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1":
    # torchaudio was already imported with the new dispatcher enabled: force
    # the legacy behavior and reload the module so the setting takes effect.
    # NOTE(review): this is a lexicographic string comparison ("10.0" < "2.1");
    # acceptable for current torchaudio version numbers, but worth confirming.
    if sys.modules["torchaudio"].__version__ >= "2.1":
        os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0"
        importlib.reload(sys.modules["torchaudio"])
        warnings.warn(
            "TORCHAUDIO_USE_BACKEND_DISPATCHER is set to 0 and torchaudio is reloaded.",
            ImportWarning,
        )
diff --git a/demucs/augment.py b/demucs/augment.py
new file mode 100644
index 00000000..6dab7f12
--- /dev/null
+++ b/demucs/augment.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Data augmentations.
+"""
+
+import random
+import torch as th
+from torch import nn
+
+
class Shift(nn.Module):
    """
    Randomly shift audio in time by up to `shift` samples.
    """
    def __init__(self, shift=8192, same=False):
        super().__init__()
        # Maximum shift in samples; `same` applies one shift to all sources.
        self.shift = shift
        self.same = same

    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        length = time - self.shift
        if self.shift <= 0:
            return wav
        if not self.training:
            # Deterministic at eval time: just trim to the common length.
            return wav[..., :length]
        # One random offset per (batch, source) — or per batch if `same`.
        srcs = 1 if self.same else sources
        offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device)
        offsets = offsets.expand(-1, sources, channels, -1)
        indexes = th.arange(length, device=wav.device)
        return wav.gather(3, indexes + offsets)
+
+
class FlipChannels(nn.Module):
    """
    Flip left-right channels.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        # Only active in training mode and for stereo input.
        if not (self.training and wav.size(2) == 2):
            return wav
        # Per (batch, source) coin flip selecting which channel goes first.
        left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
        left = left.expand(-1, -1, -1, time)
        right = 1 - left
        return th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
+
+
class FlipSign(nn.Module):
    """
    Random sign flip.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if not self.training:
            return wav
        # Per (batch, source) random sign: 0/1 mapped to -1/+1.
        flips = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
        return wav * (flips * 2 - 1)
+
+
class Remix(nn.Module):
    """
    Shuffle sources to make new mixes.
    """
    def __init__(self, proba=1, group_size=4):
        """
        Shuffle sources within one batch.
        Each batch is divided into groups of size `group_size` and shuffling is done within
        each group separately. This allow to keep the same probability distribution no matter
        the number of GPUs. Without this grouping, using more GPUs would lead to a higher
        probability of keeping two sources from the same track together which can impact
        performance.
        """
        super().__init__()
        self.proba = proba
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device

        # Only active in training mode, with probability `proba`.
        if not self.training or random.random() >= self.proba:
            return wav
        group_size = self.group_size or batch
        if batch % group_size != 0:
            raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
        groups = batch // group_size
        wav = wav.view(groups, group_size, streams, channels, time)
        # Independent random permutation per (group, stream).
        permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
                                  dim=1)
        wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
        return wav.view(batch, streams, channels, time)
+
+
class Scale(nn.Module):
    """Randomly rescale each (batch, stream) pair by a uniform factor in [min, max]."""

    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        num_batch, num_streams, _, _ = wav.size()
        # Only active in training mode, with probability `proba`.
        if not self.training or random.random() >= self.proba:
            return wav
        # One scale per (batch, stream); note the multiply is in place.
        factors = th.empty(num_batch, num_streams, 1, 1,
                           device=wav.device).uniform_(self.min, self.max)
        wav *= factors
        return wav
diff --git a/demucs/demucs.py b/demucs/demucs.py
new file mode 100644
index 00000000..f6a4305c
--- /dev/null
+++ b/demucs/demucs.py
@@ -0,0 +1,447 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import julius
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .states import capture_init
+from .utils import center_trim, unfold
+from .transformer import LayerScale
+
+
class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split in overlapping
    chunks and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        # Project the bidirectional output (2 * dim) back to dim.
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        # x: [B, C, T].
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            # Split long sequences into 50%-overlapping frames so the LSTM
            # only ever sees `max_steps` steps at a time.
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            # Fold frames into the batch dimension: [B * nframes, C, width].
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        # LSTM expects [T, B, C].
        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            # Reassemble frames, dropping `stride // 2` boundary samples on each
            # interior edge so only the well-contextualized centers are kept.
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            # Trim padding introduced by `unfold` back to the input length.
            out = out[..., :T]
            x = out
        if self.skip:
            # Residual connection with the original input.
            x = x + y
        return x
+
+
def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.

    Divides weight and bias by `(std / reference) ** 0.5`, moving the weight's
    standard deviation towards `reference`.
    """
    scale = (conv.weight.std().detach() / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale
+
+
def rescale_module(module, reference):
    """Apply `rescale_conv` to every convolutional submodule of `module`."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    for sub in module.modules():
        if isinstance(sub, conv_types):
            rescale_conv(sub, reference)
+
+
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
                NOTE(review): a negative depth disables dilation (see below).
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
                NOTE(review): this argument is overwritten below and has no
                effect — dilation is controlled by the sign of `depth`.
        """

        super().__init__()
        # Odd kernel so "same" padding keeps the sequence length unchanged.
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # Overwrites the `dilate` keyword: dilation is on iff depth > 0.
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        # Compressed inner width of the branch.
        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            # Dilation grows exponentially with layer index; padding keeps the
            # output length equal to the input length.
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            # Optional attention / LSTM are inserted after the first activation.
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        # Each layer is a residual branch: x -> x + layer(x).
        for layer in self.layers:
            x = x + layer(x)
        return x
+
+
class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiments with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        # 1x1 convolutions produce per-timestep content, query and key vectors.
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        # x: [B, C, T].
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        # Standard scaled dot-product normalization by sqrt(head dim).
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            # Frequency-based bias: cosine kernels over the key/query distance.
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            # Learned penalty growing with |t - s|, restricting the window.
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            # Append the attended frequency signature to each head's output.
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        # Residual connection around the projected attention output.
        return x + self.proj(result)
+
+
+class Demucs(nn.Module):
+ @capture_init
+ def __init__(self,
+ sources,
+ # Channels
+ audio_channels=2,
+ channels=64,
+ growth=2.,
+ # Main structure
+ depth=6,
+ rewrite=True,
+ lstm_layers=0,
+ # Convolutions
+ kernel_size=8,
+ stride=4,
+ context=1,
+ # Activations
+ gelu=True,
+ glu=True,
+ # Normalization
+ norm_starts=4,
+ norm_groups=4,
+ # DConv residual branch
+ dconv_mode=1,
+ dconv_depth=2,
+ dconv_comp=4,
+ dconv_attn=4,
+ dconv_lstm=4,
+ dconv_init=1e-4,
+ # Pre/post processing
+ normalize=True,
+ resample=True,
+ # Weight init
+ rescale=0.1,
+ # Metadata
+ samplerate=44100,
+ segment=4 * 10):
+ """
+ Args:
+ sources (list[str]): list of source names
+ audio_channels (int): stereo or mono
+ channels (int): first convolution channels
+ depth (int): number of encoder/decoder layers
+ growth (float): multiply (resp divide) number of channels by that
+ for each layer of the encoder (resp decoder)
+ depth (int): number of layers in the encoder and in the decoder.
+ rewrite (bool): add 1x1 convolution to each layer.
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
+ by default, as this is now replaced by the smaller and faster small LSTMs
+ in the DConv branches.
+ kernel_size (int): kernel size for convolutions
+ stride (int): stride for convolutions
+ context (int): kernel size of the convolution in the
+ decoder before the transposed convolution. If > 1,
+ will provide some context from neighboring time steps.
+ gelu: use GELU activation function.
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
+ norm_starts: layer at which group norm starts being used.
+ decoder layers are numbered in reverse order.
+ norm_groups: number of groups for group norm.
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+ dconv_depth: depth of residual DConv branch.
+ dconv_comp: compression of DConv branch.
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
+ dconv_init: initial scale for the DConv branch LayerScale.
+ normalize (bool): normalizes the input audio on the fly, and scales back
+ the output by the same amount.
+ resample (bool): upsample x2 the input and downsample /2 the output.
+ rescale (float): rescale initial weights of convolutions
+ to get their standard deviation closer to `rescale`.
+ samplerate (int): stored as meta information for easing
+ future evaluations of the model.
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
+ This is used by `demucs.apply.apply_model`.
+ """
+
+ super().__init__()
+ self.audio_channels = audio_channels
+ self.sources = sources
+ self.kernel_size = kernel_size
+ self.context = context
+ self.stride = stride
+ self.depth = depth
+ self.resample = resample
+ self.channels = channels
+ self.normalize = normalize
+ self.samplerate = samplerate
+ self.segment = segment
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+ self.skip_scales = nn.ModuleList()
+
+ if glu:
+ activation = nn.GLU(dim=1)
+ ch_scale = 2
+ else:
+ activation = nn.ReLU()
+ ch_scale = 1
+ if gelu:
+ act2 = nn.GELU
+ else:
+ act2 = nn.ReLU
+
+ in_channels = audio_channels
+ padding = 0
+ for index in range(depth):
+ norm_fn = lambda d: nn.Identity() # noqa
+ if index >= norm_starts:
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
+
+ encode = []
+ encode += [
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
+ norm_fn(channels),
+ act2(),
+ ]
+ attn = index >= dconv_attn
+ lstm = index >= dconv_lstm
+ if dconv_mode & 1:
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
+ compress=dconv_comp, attn=attn, lstm=lstm)]
+ if rewrite:
+ encode += [
+ nn.Conv1d(channels, ch_scale * channels, 1),
+ norm_fn(ch_scale * channels), activation]
+ self.encoder.append(nn.Sequential(*encode))
+
+ decode = []
+ if index > 0:
+ out_channels = in_channels
+ else:
+ out_channels = len(self.sources) * audio_channels
+ if rewrite:
+ decode += [
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
+ norm_fn(ch_scale * channels), activation]
+ if dconv_mode & 2:
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
+ compress=dconv_comp, attn=attn, lstm=lstm)]
+ decode += [nn.ConvTranspose1d(channels, out_channels,
+ kernel_size, stride, padding=padding)]
+ if index > 0:
+ decode += [norm_fn(out_channels), act2()]
+ self.decoder.insert(0, nn.Sequential(*decode))
+ in_channels = channels
+ channels = int(growth * channels)
+
+ channels = in_channels
+ if lstm_layers:
+ self.lstm = BLSTM(channels, lstm_layers)
+ else:
+ self.lstm = None
+
+ if rescale:
+ rescale_module(self, reference=rescale)
+
+ def valid_length(self, length):
+ """
+ Return the nearest valid length to use with the model so that
+        there are no time steps left over in a convolution, e.g. for all
+        layers, (size of the input - kernel_size) % stride == 0.
+
+ Note that input are automatically padded if necessary to ensure that the output
+ has the same length as the input.
+ """
+ if self.resample:
+ length *= 2
+
+ for _ in range(self.depth):
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
+ length = max(1, length)
+
+ for idx in range(self.depth):
+ length = (length - 1) * self.stride + self.kernel_size
+
+ if self.resample:
+ length = math.ceil(length / 2)
+ return int(length)
+
+ def forward(self, mix):
+ x = mix
+ length = x.shape[-1]
+
+ if self.normalize:
+ mono = mix.mean(dim=1, keepdim=True)
+ mean = mono.mean(dim=-1, keepdim=True)
+ std = mono.std(dim=-1, keepdim=True)
+ x = (x - mean) / (1e-5 + std)
+ else:
+ mean = 0
+ std = 1
+
+ delta = self.valid_length(length) - length
+ x = F.pad(x, (delta // 2, delta - delta // 2))
+
+ if self.resample:
+ x = julius.resample_frac(x, 1, 2)
+
+ saved = []
+ for encode in self.encoder:
+ x = encode(x)
+ saved.append(x)
+
+ if self.lstm:
+ x = self.lstm(x)
+
+ for decode in self.decoder:
+ skip = saved.pop(-1)
+ skip = center_trim(skip, x)
+ x = decode(x + skip)
+
+ if self.resample:
+ x = julius.resample_frac(x, 2, 1)
+ x = x * std + mean
+ x = center_trim(x, length)
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
+ return x
+
+ def load_state_dict(self, state, strict=True):
+ # fix a mismatch with previous generation Demucs models.
+ for idx in range(self.depth):
+ for a in ['encoder', 'decoder']:
+ for b in ['bias', 'weight']:
+ new = f'{a}.{idx}.3.{b}'
+ old = f'{a}.{idx}.2.{b}'
+ if old in state and new not in state:
+ state[new] = state.pop(old)
+ super().load_state_dict(state, strict=strict)
diff --git a/demucs/distrib.py b/demucs/distrib.py
new file mode 100644
index 00000000..dc1576cb
--- /dev/null
+++ b/demucs/distrib.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Distributed training utilities.
+"""
+import logging
+import pickle
+
+import numpy as np
+import torch
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader, Subset
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from dora import distrib as dora_distrib
+
+logger = logging.getLogger(__name__)
+rank = 0
+world_size = 1
+
+
+def init():
+ global rank, world_size
+ if not torch.distributed.is_initialized():
+ dora_distrib.init()
+ rank = dora_distrib.rank()
+ world_size = dora_distrib.world_size()
+
+
+def average(metrics, count=1.):
+ if isinstance(metrics, dict):
+ keys, values = zip(*sorted(metrics.items()))
+ values = average(values, count)
+ return dict(zip(keys, values))
+ if world_size == 1:
+ return metrics
+ tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
+ tensor *= count
+ torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
+ return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
+
+
+def wrap(model):
+ if world_size == 1:
+ return model
+ else:
+ return DistributedDataParallel(
+ model,
+ # find_unused_parameters=True,
+ device_ids=[torch.cuda.current_device()],
+ output_device=torch.cuda.current_device())
+
+
+def barrier():
+ if world_size > 1:
+ torch.distributed.barrier()
+
+
+def share(obj=None, src=0):
+ if world_size == 1:
+ return obj
+ size = torch.empty(1, device='cuda', dtype=torch.long)
+ if rank == src:
+ dump = pickle.dumps(obj)
+ size[0] = len(dump)
+ torch.distributed.broadcast(size, src=src)
+ # size variable is now set to the length of pickled obj in all processes
+
+ if rank == src:
+ buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
+ else:
+ buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
+ torch.distributed.broadcast(buffer, src=src)
+ # buffer variable is now set to pickled obj in all processes
+
+ if rank != src:
+ obj = pickle.loads(buffer.cpu().numpy().tobytes())
+ logger.debug(f"Shared object of size {len(buffer)}")
+ return obj
+
+
+def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
+ """
+ Create a dataloader properly in case of distributed training.
+ If a gradient is going to be computed you must set `shuffle=True`.
+ """
+ if world_size == 1:
+ return klass(dataset, *args, shuffle=shuffle, **kwargs)
+
+ if shuffle:
+ # train means we will compute backward, we use DistributedSampler
+ sampler = DistributedSampler(dataset)
+ # We ignore shuffle, DistributedSampler already shuffles
+ return klass(dataset, *args, **kwargs, sampler=sampler)
+ else:
+        # We make a manual shard, as DistributedSampler otherwise replicates some examples
+ dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
+ return klass(dataset, *args, shuffle=shuffle, **kwargs)
diff --git a/demucs/ema.py b/demucs/ema.py
new file mode 100644
index 00000000..101bee02
--- /dev/null
+++ b/demucs/ema.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Inspired from https://github.com/rwightman/pytorch-image-models
+from contextlib import contextmanager
+
+import torch
+
+from .states import swap_state
+
+
+class ModelEMA:
+ """
+ Perform EMA on a model. You can switch to the EMA weights temporarily
+ with the `swap` method.
+
+ ema = ModelEMA(model)
+ with ema.swap():
+ # compute valid metrics with averaged model.
+ """
+ def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
+ self.decay = decay
+ self.model = model
+ self.state = {}
+ self.count = 0
+ self.device = device
+ self.unbias = unbias
+
+ self._init()
+
+ def _init(self):
+ for key, val in self.model.state_dict().items():
+ if val.dtype != torch.float32:
+ continue
+ device = self.device or val.device
+ if key not in self.state:
+ self.state[key] = val.detach().to(device, copy=True)
+
+ def update(self):
+ if self.unbias:
+ self.count = self.count * self.decay + 1
+ w = 1 / self.count
+ else:
+ w = 1 - self.decay
+ for key, val in self.model.state_dict().items():
+ if val.dtype != torch.float32:
+ continue
+ device = self.device or val.device
+ self.state[key].mul_(1 - w)
+ self.state[key].add_(val.detach().to(device), alpha=w)
+
+ @contextmanager
+ def swap(self):
+ with swap_state(self.model, self.state):
+ yield
+
+ def state_dict(self):
+ return {'state': self.state, 'count': self.count}
+
+ def load_state_dict(self, state):
+ self.count = state['count']
+ for k, v in state['state'].items():
+ self.state[k].copy_(v)
diff --git a/demucs/evaluate.py b/demucs/evaluate.py
new file mode 100755
index 00000000..fa2ff453
--- /dev/null
+++ b/demucs/evaluate.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
+or the newest SDR definition from the MDX 2021 competition (this one will
+be reported as `nsdr` for `new sdr`).
+"""
+
+from concurrent import futures
+import logging
+
+from dora.log import LogProgress
+import numpy as np
+import musdb
+import museval
+import torch as th
+
+from .apply import apply_model
+from .audio import convert_audio, save_audio
+from . import distrib
+from .utils import DummyPoolExecutor
+
+
+logger = logging.getLogger(__name__)
+
+
+def new_sdr(references, estimates):
+ """
+ Compute the SDR according to the MDX challenge definition.
+ Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
+ """
+ assert references.dim() == 4
+ assert estimates.dim() == 4
+ delta = 1e-7 # avoid numerical errors
+ num = th.sum(th.square(references), dim=(2, 3))
+ den = th.sum(th.square(references - estimates), dim=(2, 3))
+ num += delta
+ den += delta
+ scores = 10 * th.log10(num / den)
+ return scores
+
+
+def eval_track(references, estimates, win, hop, compute_sdr=True):
+ references = references.transpose(1, 2).double()
+ estimates = estimates.transpose(1, 2).double()
+
+ new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]
+
+ if not compute_sdr:
+ return None, new_scores
+ else:
+ references = references.numpy()
+ estimates = estimates.numpy()
+ scores = museval.metrics.bss_eval(
+ references, estimates,
+ compute_permutation=False,
+ window=win,
+ hop=hop,
+ framewise_filters=False,
+ bsseval_sources_version=False)[:-1]
+ return scores, new_scores
+
+
+def evaluate(solver, compute_sdr=False):
+ """
+ Evaluate model using museval.
+ compute_sdr=False means using only the MDX definition of the SDR, which
+ is much faster to evaluate.
+ """
+
+ args = solver.args
+
+ output_dir = solver.folder / "results"
+ output_dir.mkdir(exist_ok=True, parents=True)
+ json_folder = solver.folder / "results/test"
+ json_folder.mkdir(exist_ok=True, parents=True)
+
+ # we load tracks from the original musdb set
+ if args.test.nonhq is None:
+ test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
+ else:
+ test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
+ src_rate = args.dset.musdb_samplerate
+
+ eval_device = 'cpu'
+
+ model = solver.model
+ win = int(1. * model.samplerate)
+ hop = int(1. * model.samplerate)
+
+ indexes = range(distrib.rank, len(test_set), distrib.world_size)
+ indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
+ name='Eval')
+ pendings = []
+
+ pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
+ with pool(args.test.workers) as pool:
+ for index in indexes:
+ track = test_set.tracks[index]
+
+ mix = th.from_numpy(track.audio).t().float()
+ if mix.dim() == 1:
+ mix = mix[None]
+ mix = mix.to(solver.device)
+ ref = mix.mean(dim=0) # mono mixture
+ mix = (mix - ref.mean()) / ref.std()
+ mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
+ estimates = apply_model(model, mix[None],
+ shifts=args.test.shifts, split=args.test.split,
+ overlap=args.test.overlap)[0]
+ estimates = estimates * ref.std() + ref.mean()
+ estimates = estimates.to(eval_device)
+
+ references = th.stack(
+ [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
+ if references.dim() == 2:
+ references = references[:, None]
+ references = references.to(eval_device)
+ references = convert_audio(references, src_rate,
+ model.samplerate, model.audio_channels)
+ if args.test.save:
+ folder = solver.folder / "wav" / track.name
+ folder.mkdir(exist_ok=True, parents=True)
+ for name, estimate in zip(model.sources, estimates):
+ save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)
+
+ pendings.append((track.name, pool.submit(
+ eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))
+
+ pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
+ name='Eval (BSS)')
+ tracks = {}
+ for track_name, pending in pendings:
+ pending = pending.result()
+ scores, nsdrs = pending
+ tracks[track_name] = {}
+ for idx, target in enumerate(model.sources):
+ tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
+ if scores is not None:
+ (sdr, isr, sir, sar) = scores
+ for idx, target in enumerate(model.sources):
+ values = {
+ "SDR": sdr[idx].tolist(),
+ "SIR": sir[idx].tolist(),
+ "ISR": isr[idx].tolist(),
+ "SAR": sar[idx].tolist()
+ }
+ tracks[track_name][target].update(values)
+
+ all_tracks = {}
+ for src in range(distrib.world_size):
+ all_tracks.update(distrib.share(tracks, src))
+
+ result = {}
+ metric_names = next(iter(all_tracks.values()))[model.sources[0]]
+ for metric_name in metric_names:
+ avg = 0
+ avg_of_medians = 0
+ for source in model.sources:
+ medians = [
+ np.nanmedian(all_tracks[track][source][metric_name])
+ for track in all_tracks.keys()]
+ mean = np.mean(medians)
+ median = np.median(medians)
+ result[metric_name.lower() + "_" + source] = mean
+ result[metric_name.lower() + "_med" + "_" + source] = median
+ avg += mean / len(model.sources)
+ avg_of_medians += median / len(model.sources)
+ result[metric_name.lower()] = avg
+ result[metric_name.lower() + "_med"] = avg_of_medians
+ return result
diff --git a/demucs/grids/__init__.py b/demucs/grids/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/demucs/grids/_explorers.py b/demucs/grids/_explorers.py
new file mode 100644
index 00000000..ec3a858d
--- /dev/null
+++ b/demucs/grids/_explorers.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dora import Explorer
+import treetable as tt
+
+
+class MyExplorer(Explorer):
+ test_metrics = ['nsdr', 'sdr_med']
+
+ def get_grid_metrics(self):
+ """Return the metrics that should be displayed in the tracking table.
+ """
+ return [
+ tt.group("train", [
+ tt.leaf("epoch"),
+ tt.leaf("reco", ".3f"),
+ ], align=">"),
+ tt.group("valid", [
+ tt.leaf("penalty", ".1f"),
+ tt.leaf("ms", ".1f"),
+ tt.leaf("reco", ".2%"),
+ tt.leaf("breco", ".2%"),
+ tt.leaf("b_nsdr", ".2f"),
+ # tt.leaf("b_nsdr_drums", ".2f"),
+ # tt.leaf("b_nsdr_bass", ".2f"),
+ # tt.leaf("b_nsdr_other", ".2f"),
+ # tt.leaf("b_nsdr_vocals", ".2f"),
+ ], align=">"),
+ tt.group("test", [
+ tt.leaf(name, ".2f")
+ for name in self.test_metrics
+ ], align=">")
+ ]
+
+ def process_history(self, history):
+ train = {
+ 'epoch': len(history),
+ }
+ valid = {}
+ test = {}
+ best_v_main = float('inf')
+ breco = float('inf')
+ for metrics in history:
+ train.update(metrics['train'])
+ valid.update(metrics['valid'])
+ if 'main' in metrics['valid']:
+ best_v_main = min(best_v_main, metrics['valid']['main']['loss'])
+ valid['bmain'] = best_v_main
+ valid['breco'] = min(breco, metrics['valid']['reco'])
+ breco = valid['breco']
+ if (metrics['valid']['loss'] == metrics['valid']['best'] or
+ metrics['valid'].get('nsdr') == metrics['valid']['best']):
+ for k, v in metrics['valid'].items():
+ if k.startswith('reco_'):
+ valid['b_' + k[len('reco_'):]] = v
+ if k.startswith('nsdr'):
+ valid[f'b_{k}'] = v
+ if 'test' in metrics:
+ test.update(metrics['test'])
+ metrics = history[-1]
+ return {"train": train, "valid": valid, "test": test}
diff --git a/demucs/grids/mdx.py b/demucs/grids/mdx.py
new file mode 100644
index 00000000..62d447f1
--- /dev/null
+++ b/demucs/grids/mdx.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track A MDX models.
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+
+TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']
+
+
+@MyExplorer
+def explorer(launcher):
+ launcher.slurm_(
+ gpus=8,
+ time=3 * 24 * 60,
+ partition='learnlab')
+
+ # Reproduce results from MDX competition Track A
+ # This trains the first round of models. Once this is trained,
+ # you will need to schedule `mdx_refine`.
+ for sig in TRACK_A:
+ xp = main.get_xp_from_sig(sig)
+ parent = xp.cfg.continue_from
+ xp = main.get_xp_from_sig(parent)
+ launcher(xp.argv)
+ launcher(xp.argv, {'quant.diffq': 1e-4})
+ launcher(xp.argv, {'quant.diffq': 3e-4})
diff --git a/demucs/grids/mdx_extra.py b/demucs/grids/mdx_extra.py
new file mode 100644
index 00000000..b99a37b0
--- /dev/null
+++ b/demucs/grids/mdx_extra.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track B MDX models.
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']
+
+
+@MyExplorer
+def explorer(launcher):
+ launcher.slurm_(
+ gpus=8,
+ time=3 * 24 * 60,
+ partition='learnlab')
+
+    # Reproduce results from MDX competition Track B.
+    # For each model, walk the `continue_from` chain back to the root
+    # experiment, then train it on the extra datasets.
+ for sig in TRACK_B:
+ while sig is not None:
+ xp = main.get_xp_from_sig(sig)
+ sig = xp.cfg.continue_from
+
+ for dset in ['extra44', 'extra_test']:
+ sub = launcher.bind(xp.argv, dset=dset)
+ sub()
+ if dset == 'extra_test':
+ sub({'quant.diffq': 1e-4})
+ sub({'quant.diffq': 3e-4})
diff --git a/demucs/grids/mdx_refine.py b/demucs/grids/mdx_refine.py
new file mode 100644
index 00000000..f62da1de
--- /dev/null
+++ b/demucs/grids/mdx_refine.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Refinement training for the Track A MDX models.
+"""
+
+from ._explorers import MyExplorer
+from .mdx import TRACK_A
+from ..train import main
+
+
+@MyExplorer
+def explorer(launcher):
+ launcher.slurm_(
+ gpus=8,
+ time=3 * 24 * 60,
+ partition='learnlab')
+
+ # Reproduce results from MDX competition Track A
+ # WARNING: all the experiments in the `mdx` grid must have completed.
+ for sig in TRACK_A:
+ xp = main.get_xp_from_sig(sig)
+ launcher(xp.argv)
+ for diffq in [1e-4, 3e-4]:
+ xp_src = main.get_xp_from_sig(xp.cfg.continue_from)
+ q_argv = [f'quant.diffq={diffq}']
+ actual_src = main.get_xp(xp_src.argv + q_argv)
+ actual_src.link.load()
+ assert len(actual_src.link.history) == actual_src.cfg.epochs
+ argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"']
+ launcher(argv)
diff --git a/demucs/grids/mmi.py b/demucs/grids/mmi.py
new file mode 100644
index 00000000..d75aa2b6
--- /dev/null
+++ b/demucs/grids/mmi.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+
+
+@MyExplorer
+def explorer(launcher: Launcher):
+ launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days
+
+ sub = launcher.bind_(
+ {
+ "dset": "extra_mmi_goodclean",
+ "test.shifts": 0,
+ "model": "htdemucs",
+ "htdemucs.dconv_mode": 3,
+ "htdemucs.depth": 4,
+ "htdemucs.t_dropout": 0.02,
+ "htdemucs.t_layers": 5,
+ "max_batches": 800,
+ "ema.epoch": [0.9, 0.95],
+ "ema.batch": [0.9995, 0.9999],
+ "dset.segment": 10,
+ "batch_size": 32,
+ }
+ )
+ sub({"model": "hdemucs"})
+ sub({"model": "hdemucs", "dset": "extra44"})
+ sub({"model": "hdemucs", "dset": "musdb44"})
+
+ sparse = {
+ 'batch_size': 3 * 8,
+ 'augment.remix.group_size': 3,
+ 'htdemucs.t_auto_sparsity': True,
+ 'htdemucs.t_sparse_self_attn': True,
+ 'htdemucs.t_sparse_cross_attn': True,
+ 'htdemucs.t_sparsity': 0.9,
+ "htdemucs.t_layers": 7
+ }
+
+ with launcher.job_array():
+ for transf_layers in [5, 7]:
+ for bottom_channels in [0, 512]:
+ sub = launcher.bind({
+ "htdemucs.t_layers": transf_layers,
+ "htdemucs.bottom_channels": bottom_channels,
+ })
+ if bottom_channels == 0 and transf_layers == 5:
+ sub({"augment.remix.proba": 0.0})
+ sub({
+ "augment.repitch.proba": 0.0,
+                        # when doing repitching, we trim the output to align on the
+ # highest change of BPM. When removing repitching,
+ # we simulate it here to ensure the training context is the same.
+ # Another second is lost for all experiments due to the random
+ # shift augmentation.
+ "dset.segment": 10 * 0.88})
+ elif bottom_channels == 512 and transf_layers == 5:
+ sub(dset="musdb44")
+ sub(dset="extra44")
+ # Sparse kernel XP, currently not released as kernels are still experimental.
+ sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7})
+
+ for duration in [5, 10, 15]:
+ sub({"dset.segment": duration})
diff --git a/demucs/grids/mmi_ft.py b/demucs/grids/mmi_ft.py
new file mode 100644
index 00000000..73e488b5
--- /dev/null
+++ b/demucs/grids/mmi_ft.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+from demucs import train
+
+
+def get_sub(launcher, sig):
+ xp = train.main.get_xp_from_sig(sig)
+ sub = launcher.bind(xp.argv)
+ sub()
+ sub.bind_({
+ 'continue_from': sig,
+ 'continue_best': True})
+ return sub
+
+
+@MyExplorer
+def explorer(launcher: Launcher):
+ launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days
+ ft = {
+ 'optim.lr': 1e-4,
+ 'augment.remix.proba': 0,
+ 'augment.scale.proba': 0,
+ 'augment.shift_same': True,
+ 'htdemucs.t_weight_decay': 0.05,
+ 'batch_size': 8,
+ 'optim.clip_grad': 5,
+ 'optim.optim': 'adamw',
+ 'epochs': 50,
+ 'dset.wav2_valid': True,
+ 'ema.epoch': [], # let's make valid a bit faster
+ }
+ with launcher.job_array():
+ for sig in ['2899e11a']:
+ sub = get_sub(launcher, sig)
+ sub.bind_(ft)
+ for segment in [15, 18]:
+ for source in range(4):
+ w = [0] * 4
+ w[source] = 1
+ sub({'weights': w, 'dset.segment': segment})
+
+ for sig in ['955717e8']:
+ sub = get_sub(launcher, sig)
+ sub.bind_(ft)
+ for segment in [10, 15]:
+ for source in range(4):
+ w = [0] * 4
+ w[source] = 1
+ sub({'weights': w, 'dset.segment': segment})
diff --git a/demucs/grids/repro.py b/demucs/grids/repro.py
new file mode 100644
index 00000000..21d33fce
--- /dev/null
+++ b/demucs/grids/repro.py
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Easier training for reproducibility
+"""
+
+from ._explorers import MyExplorer
+
+
+@MyExplorer
+def explorer(launcher):
+ launcher.slurm_(
+ gpus=8,
+ time=3 * 24 * 60,
+ partition='devlab,learnlab')
+
+ launcher.bind_({'ema.epoch': [0.9, 0.95]})
+ launcher.bind_({'ema.batch': [0.9995, 0.9999]})
+ launcher.bind_({'epochs': 600})
+
+ base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False,
+ 'demucs.lstm_layers': 2}
+ newt = {'model': 'demucs', 'demucs.normalize': True}
+ hdem = {'model': 'hdemucs'}
+ svd = {'svd.penalty': 1e-5, 'svd': 'base2'}
+
+ with launcher.job_array():
+ for model in [base, newt, hdem]:
+ sub = launcher.bind(model)
+ if model is base:
+ # Training the v2 Demucs on MusDB HQ
+ sub(epochs=360)
+ continue
+
+ # those two will be used in the repro_mdx_a bag of models.
+ sub(svd)
+ sub(svd, seed=43)
+ if model == newt:
+ # Ablation study
+ sub()
+ abl = sub.bind(svd)
+ abl({'ema.epoch': [], 'ema.batch': []})
+ abl({'demucs.dconv_lstm': 10})
+ abl({'demucs.dconv_attn': 10})
+ abl({'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10, 'demucs.lstm_layers': 2})
+ abl({'demucs.dconv_mode': 0})
+ abl({'demucs.gelu': False})
diff --git a/demucs/grids/repro_ft.py b/demucs/grids/repro_ft.py
new file mode 100644
index 00000000..7bb4ee89
--- /dev/null
+++ b/demucs/grids/repro_ft.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Fine tuning experiments
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+
+@MyExplorer
+def explorer(launcher):
+ launcher.slurm_(
+ gpus=8,
+ time=300,
+ partition='devlab,learnlab')
+
+ # Mus
+ launcher.slurm_(constraint='volta32gb')
+
+ grid = "repro"
+ folder = main.dora.dir / "grids" / grid
+
+ for sig in folder.iterdir():
+ if not sig.is_symlink():
+ continue
+ xp = main.get_xp_from_sig(sig)
+ xp.link.load()
+ if len(xp.link.history) != xp.cfg.epochs:
+ continue
+ sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"'])
+ sub.bind_({'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]})
+ sub.bind_({'test.every': 1, 'test.sdr': True, 'epochs': 4})
+ sub.bind_({'dset.segment': 28, 'dset.shift': 2})
+ sub.bind_({'batch_size': 32})
+ auto = {'dset': 'auto_mus'}
+ auto.update({'augment.remix.proba': 0, 'augment.scale.proba': 0,
+ 'augment.shift_same': True})
+ sub.bind_(auto)
+ sub.bind_({'batch_size': 16})
+ sub.bind_({'optim.lr': 1e-4})
+ sub.bind_({'model_segment': 44})
+ sub()
diff --git a/demucs/grids/sdx23.py b/demucs/grids/sdx23.py
new file mode 100644
index 00000000..3bdb4191
--- /dev/null
+++ b/demucs/grids/sdx23.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+
+
+@MyExplorer
+def explorer(launcher: Launcher):
+ launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair",
+ mem_per_gpu=None, constraint='')
+ launcher.bind_({"dset.use_musdb": False})
+
+ with launcher.job_array():
+ launcher(dset='sdx23_bleeding')
+ launcher(dset='sdx23_labelnoise')
diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py
new file mode 100644
index 00000000..9992b60a
--- /dev/null
+++ b/demucs/hdemucs.py
@@ -0,0 +1,796 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+This code contains the spectrogram and Hybrid version of Demucs.
+"""
+from copy import deepcopy
+import math
+import typing as tp
+
+from openunmix.filtering import wiener
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .demucs import DConv, rescale_module
+from .states import capture_init
+from .spec import spectro, ispectro
+
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Wrapper around `F.pad` that tolerates reflect padding on short inputs.

    `F.pad` with ``mode='reflect'`` requires the input to be longer than the
    padding on either side. When it is not, the signal is first extended with
    zeros on the right until reflection becomes legal, and the requested
    paddings are reduced by the amount of zeros inserted, so the final length
    is unchanged.
    """
    original = x
    size = x.shape[-1]
    left, right = paddings
    if mode == 'reflect':
        biggest = max(left, right)
        if size <= biggest:
            # Grow the signal with zeros so that `size > biggest` holds.
            missing = biggest - size + 1
            grow_right = min(right, missing)
            grow_left = missing - grow_right
            paddings = (left - grow_left, right - grow_right)
            x = F.pad(x, (grow_left, grow_right))
    out = F.pad(x, paddings, mode, value)
    # Sanity checks: total length matches the request, and the original
    # samples sit untouched at their expected offset.
    assert out.shape[-1] == size + left + right
    assert (out[..., left: left + size] == original).all()
    return out
+
+
class ScaledEmbedding(nn.Module):
    """Embedding whose effective weights are multiplied by `scale`.

    The table is stored divided by `scale` and multiplied back on every use,
    which effectively boosts the learning rate of the embedding. With
    `smooth`, the table is initialized as a normalized cumulative sum so that
    neighboring indices map to nearby vectors.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            smoothed = torch.cumsum(self.embedding.weight.data, dim=0)
            # A sum of n gaussians grows like sqrt(n); divide by that so
            # every row keeps a comparable magnitude.
            counts = torch.arange(1, num_embeddings + 1).to(smoothed)
            self.embedding.weight.data[:] = smoothed / counts.sqrt()[:, None]
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        # Expose the rescaled table, matching what `forward` computes.
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
+
+
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
                 rewrite=True):
        """Encoder layer. This is used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: dict of kwargs for the DConv class, or None for defaults.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        if dconv_kw is None:
            # Fix: the previous `dconv_kw={}` default was a shared mutable
            # dict across all instances (classic Python pitfall).
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # Frequency branch: convolve along the frequency axis only.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            # "Empty" layers stop after the first conv (used right before
            # merging the time and frequency branches).
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            # 1x1 conv (with optional temporal context) followed by a GLU.
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            # Collapse a 4d (B, C, Fr, T) input for the 1d time convolution.
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            # Pad on the right so the time length is divisible by the stride.
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                # Broadcast the time-branch signal over the frequency axis.
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                # DConv is 1d: fold the frequency axis into the batch.
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
+
+
class MultiWrap(nn.Module):
    """
    Takes one layer and replicate it N times. each replica will act
    on a frequency band. All is done so that if the N replica have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # True when wrapping an encoder layer, False for a decoder layer.
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        # One replica per band: len(split_ratios) boundaries -> N+1 bands.
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # Padding is re-applied per band in `forward`.
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            # Re-initialize each replica so they start with distinct weights.
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        """Apply each replica to its frequency band and stitch the results.

        `skip` and `length` are only used in the decoder case (they are
        forwarded to each HDecLayer replica).
        """
        B, C, Fr, T = x.shape

        # Last band always extends to the top frequency (ratio 1).
        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                # Encoder: slice a band, pad its outer edges manually, and
                # advance `start` so consecutive bands overlap by exactly
                # the amount the convolution window requires.
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    # Number of conv output frames that fit in this band.
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                # Decoder: run each band with `last=True` to get the raw
                # (pre-activation) output, then overlap-add the stride-wide
                # seam with the previous band, removing the double-counted
                # transposed-conv bias.
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        if not self.conv and not last:
            # Apply the activation that was skipped by forcing last=True.
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            # Match HDecLayer's (z, y) return signature; `y` is not needed.
            return out, None
+
+
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.

        `dconv_kw=None` means default DConv kwargs (fixes the previous shared
        mutable `{}` default argument).
        """
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                # Context along time only, none across frequencies.
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        """Decode one layer.

        Returns `(z, y)`: `z` is the upsampled output, `y` the activation just
        before the transposed convolution (used where the time and frequency
        branches separate again).
        """
        if self.freq and x.dim() == 3:
            # Restore the (B, C, Fr, T) layout for 2d convolutions.
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            # U-Net skip connection.
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    # DConv is 1d: fold the frequency axis into the batch.
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                # Trim the frequency padding introduced by the encoder.
                z = z[..., self.pad:-self.pad, :]
        else:
            # Time branch: crop to the exact saved length.
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
+
+
class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=48,
                 channels_time=None,
                 growth=2,
                 # STFT
                 nfft=4096,
                 wiener_iters=0,
                 end_iters=0,
                 wiener_residual=False,
                 cac=True,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 hybrid=True,
                 hybrid_old=False,
                 # Frequency branch
                 multi_freqs=None,
                 multi_freqs_depth=2,
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 time_stride=2,
                 stride=4,
                 context=1,
                 context_enc=0,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick

        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        # STFT hop is always a quarter of the window (relied on in `_spec`).
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            # Complex-as-channels: real and imaginary parts double the channels.
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        # Build the encoder/decoder pairs from the outermost layer inwards.
        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                # Frequency axis fully collapsed: switch to time convolutions.
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: one conv covers all remaining bins.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'attn': attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'gelu': True,
                }
            }
            # Same kwargs for the time branch, but with 1d convolutions.
            kwt = dict(kw)
            kwt['freq'] = 0
            kwt['kernel_size'] = kernel_size
            kwt['stride'] = stride
            kwt['pad'] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec['context_freq'] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
                                 empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The decoder outputs one set of audio channels per source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
                                 last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                # Frequency embedding added after the first freq layer (see forward).
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x):
        """Compute the STFT of `x`, padded so that the number of output frames
        is exactly `ceil(time_length / hop_length)` for hybrid models."""
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft).
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention allow to easily
            # align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
            else:
                # Replicates the padding bug of early MDX models.
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        # Drop the Nyquist bin so the frequency dimension is a power of two.
        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            # Remove the frames corresponding to the extra padding above.
            z = z[..., 2:2+le]
        return z

    def _ispec(self, z, length=None, scale=0):
        """Inverse of `_spec`: restore the Nyquist bin and padding, then take
        the iSTFT and crop back to `length` samples."""
        hl = self.hop_length // (4 ** scale)
        z = F.pad(z, (0, 0, 0, 1))
        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad:pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            # Undo the channel packing of `_magnitude` to get complex stems.
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            # Naive masking: keep the mixture phase, apply the magnitude mask.
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        # Filter in windows of 300 frames to bound memory usage.
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        # OpenUnmix expects (T, F, C, S) per sample.
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame], mix_stft[sample, frame], niters,
                    residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            # Drop the extra residual source added by wiener().
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        """Separate `mix` into stems.

        Args:
            mix: waveform tensor; assumed (batch, channels, time) — consistent
                with the normalization over dims (1, 2) below.
        Returns:
            Tensor of shape (batch, sources, channels, time).
        """
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid:
            # NOTE(review): zeroed from the freq-branch tensor `x`; at the
            # bottleneck both branches share this shape — confirm.
            xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            if self.hybrid:
                offset = self.depth - len(self.tdecoder)
            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        # Undo the spectrogram normalization.
        x = x * std[:, None] + mean[:, None]

        # to cpu as mps doesnt support complex numbers
        # demucs issue #435 ##432
        # NOTE: in this case z already is on cpu
        # TODO: remove this when mps supports complex numbers
        x_is_mps_xpu = x.device.type in ["mps", "xpu"]
        x_device = x.device
        if x_is_mps_xpu:
            x = x.cpu()

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        # back to mps device
        if x_is_mps_xpu:
            x = x.to(x_device)


        if self.hybrid:
            # Undo the waveform normalization and sum the two branches.
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x
        return x
diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py
new file mode 100644
index 00000000..56568608
--- /dev/null
+++ b/demucs/htdemucs.py
@@ -0,0 +1,661 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# First author is Simon Rouard.
+"""
+This code contains the spectrogram and Hybrid version of Demucs.
+"""
+import math
+
+from openunmix.filtering import wiener
+import torch
+from torch import nn
+from torch.nn import functional as F
+from fractions import Fraction
+from einops import rearrange
+
+from .transformer import CrossTransformerEncoder
+
+from .demucs import rescale_module
+from .states import capture_init
+from .spec import spectro, ispectro
+from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
+
+
+class HTDemucs(nn.Module):
+ """
+ Spectrogram and hybrid Demucs model.
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
+ Frequency layers can still access information across time steps thanks to the DConv residual.
+
+ Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
+
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
+ Open Unmix implementation [Stoter et al. 2019].
+
+ The loss is always on the temporal domain, by backpropagating through the above
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
+ contribution, without changing the one from the waveform, which will lead to worse performance.
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
+ hybrid models.
+
+ This model also uses frequency embeddings are used to improve efficiency on convolutions
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
+
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
+ """
+
+ @capture_init
+ def __init__(
+ self,
+ sources,
+ # Channels
+ audio_channels=2,
+ channels=48,
+ channels_time=None,
+ growth=2,
+ # STFT
+ nfft=4096,
+ wiener_iters=0,
+ end_iters=0,
+ wiener_residual=False,
+ cac=True,
+ # Main structure
+ depth=4,
+ rewrite=True,
+ # Frequency branch
+ multi_freqs=None,
+ multi_freqs_depth=3,
+ freq_emb=0.2,
+ emb_scale=10,
+ emb_smooth=True,
+ # Convolutions
+ kernel_size=8,
+ time_stride=2,
+ stride=4,
+ context=1,
+ context_enc=0,
+ # Normalization
+ norm_starts=4,
+ norm_groups=4,
+ # DConv residual branch
+ dconv_mode=1,
+ dconv_depth=2,
+ dconv_comp=8,
+ dconv_init=1e-3,
+ # Before the Transformer
+ bottom_channels=0,
+ # Transformer
+ t_layers=5,
+ t_emb="sin",
+ t_hidden_scale=4.0,
+ t_heads=8,
+ t_dropout=0.0,
+ t_max_positions=10000,
+ t_norm_in=True,
+ t_norm_in_group=False,
+ t_group_norm=False,
+ t_norm_first=True,
+ t_norm_out=True,
+ t_max_period=10000.0,
+ t_weight_decay=0.0,
+ t_lr=None,
+ t_layer_scale=True,
+ t_gelu=True,
+ t_weight_pos_embed=1.0,
+ t_sin_random_shift=0,
+ t_cape_mean_normalize=True,
+ t_cape_augment=True,
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
+ t_sparse_self_attn=False,
+ t_sparse_cross_attn=False,
+ t_mask_type="diag",
+ t_mask_random_seed=42,
+ t_sparse_attn_window=500,
+ t_global_window=100,
+ t_sparsity=0.95,
+ t_auto_sparsity=False,
+ # ------ Particuliar parameters
+ t_cross_first=False,
+ # Weight init
+ rescale=0.1,
+ # Metadata
+ samplerate=44100,
+ segment=10,
+ use_train_segment=True,
+ ):
+ """
+ Args:
+ sources (list[str]): list of source names.
+ audio_channels (int): input/output audio channels.
+ channels (int): initial number of hidden channels.
+ channels_time: if not None, use a different `channels` value for the time branch.
+ growth: increase the number of hidden channels by this factor at each layer.
+ nfft: number of fft bins. Note that changing this require careful computation of
+ various shape parameters and will not work out of the box for hybrid models.
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
+ wiener_residual: add residual source before wiener filtering.
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
+ in input and output. no further processing is done before ISTFT.
+ depth (int): number of layers in the encoder and in the decoder.
+ rewrite (bool): add 1x1 convolution to each layer.
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
+ layers will be wrapped.
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
+ the actual value controls the weight of the embedding.
+ emb_scale: equivalent to scaling the embedding learning rate
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
+ kernel_size: kernel_size for encoder and decoder layers.
+ stride: stride for encoder and decoder layers.
+ time_stride: stride for the final time layer, after the merge.
+ context: context for 1x1 conv in the decoder.
+ context_enc: context for 1x1 conv in the encoder.
+ norm_starts: layer at which group norm starts being used.
+ decoder layers are numbered in reverse order.
+ norm_groups: number of groups for group norm.
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+ dconv_depth: depth of residual DConv branch.
+ dconv_comp: compression of DConv branch.
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
+ dconv_init: initial scale for the DConv branch LayerScale.
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
+ transformer in order to change the number of channels
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
+ t_emb: "sin", "cape" or "scaled"
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
+ for instance if C = 384 (the number of channels in the transformer) and
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
+ 384 * 4 = 1536
+ t_heads: number of heads for the transformer
+ t_dropout: dropout in the transformer
+ t_max_positions: max_positions for the "scaled" positional embedding, only
+ useful if t_emb="scaled"
+ t_norm_in: (bool) norm before addinf positional embedding and getting into the
+ transformer layers
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
+ timesteps (GroupNorm with group=1)
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
+ timesteps (GroupNorm with group=1)
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
+ t_max_period: (float) denominator in the sinusoidal embedding expression
+ t_weight_decay: (float) weight decay for the transformer
+ t_lr: (float) specific learning rate for the transformer
+ t_layer_scale: (bool) Layer Scale for the transformer
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
+ t_weight_pos_embed: (float) weighting of the positional embedding
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
+ see: https://arxiv.org/abs/2106.03143
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
+ during the inference, see: https://arxiv.org/abs/2106.03143
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
+ see: https://arxiv.org/abs/2106.03143
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
+ unless you designed really specific masks)
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
+ that generated the random part of the mask
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
+            a key (j), the mask is True if |i-j|<=t_sparse_attn_window
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
+ and mask[:, :t_global_window] will be True
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
+ level of the random part of the mask.
+ t_cross_first: (bool) if True cross attention is the first layer of the
+ transformer (False seems to be better)
+ rescale: weight rescaling trick
+ use_train_segment: (bool) if True, the actual size that is used during the
+ training is used during inference.
+ """
+ super().__init__()
+ self.cac = cac
+ self.wiener_residual = wiener_residual
+ self.audio_channels = audio_channels
+ self.sources = sources
+ self.kernel_size = kernel_size
+ self.context = context
+ self.stride = stride
+ self.depth = depth
+ self.bottom_channels = bottom_channels
+ self.channels = channels
+ self.samplerate = samplerate
+ self.segment = segment
+ self.use_train_segment = use_train_segment
+ self.nfft = nfft
+ self.hop_length = nfft // 4
+ self.wiener_iters = wiener_iters
+ self.end_iters = end_iters
+ self.freq_emb = None
+ assert wiener_iters == end_iters
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ self.tencoder = nn.ModuleList()
+ self.tdecoder = nn.ModuleList()
+
+ chin = audio_channels
+ chin_z = chin # number of channels for the freq branch
+ if self.cac:
+ chin_z *= 2
+ chout = channels_time or channels
+ chout_z = channels
+ freqs = nfft // 2
+
+ for index in range(depth):
+ norm = index >= norm_starts
+ freq = freqs > 1
+ stri = stride
+ ker = kernel_size
+ if not freq:
+ assert freqs == 1
+ ker = time_stride * 2
+ stri = time_stride
+
+ pad = True
+ last_freq = False
+ if freq and freqs <= kernel_size:
+ ker = freqs
+ pad = False
+ last_freq = True
+
+ kw = {
+ "kernel_size": ker,
+ "stride": stri,
+ "freq": freq,
+ "pad": pad,
+ "norm": norm,
+ "rewrite": rewrite,
+ "norm_groups": norm_groups,
+ "dconv_kw": {
+ "depth": dconv_depth,
+ "compress": dconv_comp,
+ "init": dconv_init,
+ "gelu": True,
+ },
+ }
+ kwt = dict(kw)
+ kwt["freq"] = 0
+ kwt["kernel_size"] = kernel_size
+ kwt["stride"] = stride
+ kwt["pad"] = True
+ kw_dec = dict(kw)
+ multi = False
+ if multi_freqs and index < multi_freqs_depth:
+ multi = True
+ kw_dec["context_freq"] = False
+
+ if last_freq:
+ chout_z = max(chout, chout_z)
+ chout = chout_z
+
+ enc = HEncLayer(
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
+ )
+ if freq:
+ tenc = HEncLayer(
+ chin,
+ chout,
+ dconv=dconv_mode & 1,
+ context=context_enc,
+ empty=last_freq,
+ **kwt
+ )
+ self.tencoder.append(tenc)
+
+ if multi:
+ enc = MultiWrap(enc, multi_freqs)
+ self.encoder.append(enc)
+ if index == 0:
+ chin = self.audio_channels * len(self.sources)
+ chin_z = chin
+ if self.cac:
+ chin_z *= 2
+ dec = HDecLayer(
+ chout_z,
+ chin_z,
+ dconv=dconv_mode & 2,
+ last=index == 0,
+ context=context,
+ **kw_dec
+ )
+ if multi:
+ dec = MultiWrap(dec, multi_freqs)
+ if freq:
+ tdec = HDecLayer(
+ chout,
+ chin,
+ dconv=dconv_mode & 2,
+ empty=last_freq,
+ last=index == 0,
+ context=context,
+ **kwt
+ )
+ self.tdecoder.insert(0, tdec)
+ self.decoder.insert(0, dec)
+
+ chin = chout
+ chin_z = chout_z
+ chout = int(growth * chout)
+ chout_z = int(growth * chout_z)
+ if freq:
+ if freqs <= kernel_size:
+ freqs = 1
+ else:
+ freqs //= stride
+ if index == 0 and freq_emb:
+ self.freq_emb = ScaledEmbedding(
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
+ )
+ self.freq_emb_scale = freq_emb
+
+ if rescale:
+ rescale_module(self, reference=rescale)
+
+ transformer_channels = channels * growth ** (depth - 1)
+ if bottom_channels:
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
+ self.channel_downsampler = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+ self.channel_upsampler_t = nn.Conv1d(
+ transformer_channels, bottom_channels, 1
+ )
+ self.channel_downsampler_t = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+
+ transformer_channels = bottom_channels
+
+ if t_layers > 0:
+ self.crosstransformer = CrossTransformerEncoder(
+ dim=transformer_channels,
+ emb=t_emb,
+ hidden_scale=t_hidden_scale,
+ num_heads=t_heads,
+ num_layers=t_layers,
+ cross_first=t_cross_first,
+ dropout=t_dropout,
+ max_positions=t_max_positions,
+ norm_in=t_norm_in,
+ norm_in_group=t_norm_in_group,
+ group_norm=t_group_norm,
+ norm_first=t_norm_first,
+ norm_out=t_norm_out,
+ max_period=t_max_period,
+ weight_decay=t_weight_decay,
+ lr=t_lr,
+ layer_scale=t_layer_scale,
+ gelu=t_gelu,
+ sin_random_shift=t_sin_random_shift,
+ weight_pos_embed=t_weight_pos_embed,
+ cape_mean_normalize=t_cape_mean_normalize,
+ cape_augment=t_cape_augment,
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
+ sparse_self_attn=t_sparse_self_attn,
+ sparse_cross_attn=t_sparse_cross_attn,
+ mask_type=t_mask_type,
+ mask_random_seed=t_mask_random_seed,
+ sparse_attn_window=t_sparse_attn_window,
+ global_window=t_global_window,
+ sparsity=t_sparsity,
+ auto_sparsity=t_auto_sparsity,
+ )
+ else:
+ self.crosstransformer = None
+
def _spec(self, x):
    """Compute the STFT of `x`, padded so that the number of output frames
    is exactly ``ceil(T / hop_length)`` for an input of length ``T``."""
    hop = self.hop_length
    n_fft = self.nfft

    # Convention used throughout the model: the time axis of the
    # spectrogram equals the input length divided by the hop length (when
    # divisible).  torch.stft cannot pad by nfft // 4 itself, so we pad by
    # hand and trim the two extra frames that appear on each side.
    assert hop == n_fft // 4
    n_frames = int(math.ceil(x.shape[-1] / hop))
    left = hop // 2 * 3
    right = left + n_frames * hop - x.shape[-1]
    x = pad1d(x, (left, right), mode="reflect")

    z = spectro(x, n_fft, hop)[..., :-1, :]
    assert z.shape[-1] == n_frames + 4, (z.shape, x.shape, n_frames)
    return z[..., 2: 2 + n_frames]
+
def _ispec(self, z, length=None, scale=0):
    """Inverse of `_spec`: restore the trimmed bins/frames, invert the STFT
    and cut the waveform back to exactly `length` samples."""
    hop = self.hop_length // (4 ** scale)
    # Undo the trimming done in `_spec`: put back the dropped last
    # frequency bin and the two frames removed on each side of the time axis.
    z = F.pad(F.pad(z, (0, 0, 0, 1)), (2, 2))
    pad = hop // 2 * 3
    padded_length = hop * int(math.ceil(length / hop)) + 2 * pad
    x = ispectro(z, hop, length=padded_length)
    return x[..., pad: pad + length]
+
def _magnitude(self, z):
    """Return a real-valued view of the complex spectrogram `z`.

    With `cac` (complex-as-channels) the real and imaginary parts are
    folded into the channel dimension (doubling it); otherwise the
    magnitude is returned.
    """
    if not self.cac:
        return z.abs()
    B, C, Fr, T = z.shape
    folded = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
    return folded.reshape(B, C * 2, Fr, T)
+
def _mask(self, z, m):
    """Turn the network output `m` into complex source spectrograms, given
    the mixture spectrogram `z`.

    If `cac` is True, `m` already is a full spectrogram and `z` is ignored.
    """
    if self.cac:
        # Unfold the real/imag pair from the channel axis back into a
        # complex tensor; no masking is applied in this mode.
        B, S, C, Fr, T = m.shape
        unfolded = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        return torch.view_as_complex(unfolded.contiguous())
    niters = self.end_iters if self.training else self.wiener_iters
    if niters < 0:
        # Ratio mask on the mixture phase.
        mix = z[:, None]
        return mix / (1e-8 + mix.abs()) * m
    return self._wiener(m, z, niters)
+
def _wiener(self, mag_out, mix_stft, niters):
    # apply wiener filtering from OpenUnmix.
    # mag_out: (B, S, C, Fq, T) magnitude estimates per source.
    # mix_stft: complex mixture spectrogram, (B, C, Fq, T).
    # Returns complex per-source spectrograms with mag_out's layout.
    init = mix_stft.dtype
    wiener_win_len = 300  # process the time axis in windows of 300 frames
    residual = self.wiener_residual

    B, S, C, Fq, T = mag_out.shape
    # Reorder so the windowed slicing below runs over the time axis
    # (batch, time, freq, channel, source ordering for `wiener`).
    mag_out = mag_out.permute(0, 4, 3, 2, 1)
    mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

    outs = []
    for sample in range(B):
        pos = 0
        out = []
        for pos in range(0, T, wiener_win_len):
            frame = slice(pos, pos + wiener_win_len)
            z_out = wiener(
                mag_out[sample, frame],
                mix_stft[sample, frame],
                niters,
                residual=residual,
            )
            out.append(z_out.transpose(-1, -2))
        outs.append(torch.cat(out, dim=0))
    out = torch.view_as_complex(torch.stack(outs, 0))
    # Back to (B, S, C, Fq, T).
    out = out.permute(0, 4, 3, 2, 1).contiguous()
    if residual:
        # NOTE(review): presumably `wiener` appends an extra residual
        # source when residual=True, dropped here — confirm against
        # the OpenUnmix wiener implementation.
        out = out[:, :-1]
    assert list(out.shape) == [B, S, C, Fq, T]
    return out.to(init)
+
def valid_length(self, length: int):
    """Return the length to use at evaluation time.

    When `use_train_segment` is set, inference must run on exactly the
    segment size seen during training, so the training length is returned;
    a longer input raises `ValueError`.
    """
    if not self.use_train_segment:
        return length
    training_length = int(self.segment * self.samplerate)
    if length > training_length:
        raise ValueError(
            f"Given length {length} is longer than "
            f"training length {training_length}")
    return training_length
+
def forward(self, mix):
    """Separate `mix` into stems.

    Args:
        mix: waveform tensor of shape (batch, channels, time).

    Returns:
        Tensor of shape (batch, len(self.sources), channels, time): the sum
        of the frequency-branch (masked STFT) and time-branch estimates.
    """
    length = mix.shape[-1]
    length_pre_pad = None
    if self.use_train_segment:
        if self.training:
            # Remember the segment length seen during training so that
            # inference can replicate it exactly.
            self.segment = Fraction(mix.shape[-1], self.samplerate)
        else:
            training_length = int(self.segment * self.samplerate)
            if mix.shape[-1] < training_length:
                # Pad shorter inputs up to the training length; the
                # padding is stripped again at the very end.
                length_pre_pad = mix.shape[-1]
                mix = F.pad(mix, (0, training_length - length_pre_pad))
    z = self._spec(mix)
    mag = self._magnitude(z).to(mix.device)
    x = mag

    B, C, Fq, T = x.shape

    # unlike previous Demucs, we always normalize because it is easier.
    mean = x.mean(dim=(1, 2, 3), keepdim=True)
    std = x.std(dim=(1, 2, 3), keepdim=True)
    x = (x - mean) / (1e-5 + std)
    # x will be the freq. branch input.

    # Prepare the time branch input.
    xt = mix
    meant = xt.mean(dim=(1, 2), keepdim=True)
    stdt = xt.std(dim=(1, 2), keepdim=True)
    xt = (xt - meant) / (1e-5 + stdt)

    # okay, this is a giant mess I know...
    saved = []  # skip connections, freq.
    saved_t = []  # skip connections, time.
    lengths = []  # saved lengths to properly remove padding, freq branch.
    lengths_t = []  # saved lengths for time branch.
    for idx, encode in enumerate(self.encoder):
        lengths.append(x.shape[-1])
        inject = None
        if idx < len(self.tencoder):
            # we have not yet merged branches.
            lengths_t.append(xt.shape[-1])
            tenc = self.tencoder[idx]
            xt = tenc(xt)
            if not tenc.empty:
                # save for skip connection
                saved_t.append(xt)
            else:
                # tenc contains just the first conv., so that now time and freq.
                # branches have the same shape and can be merged.
                inject = xt
        x = encode(x, inject)
        if idx == 0 and self.freq_emb is not None:
            # add frequency embedding to allow for non equivariant convolutions
            # over the frequency axis.
            frs = torch.arange(x.shape[-2], device=x.device)
            emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
            x = x + self.freq_emb_scale * emb

        saved.append(x)
    if self.crosstransformer:
        if self.bottom_channels:
            # Temporarily change the channel count with the 1x1
            # convolutions set up in __init__ around the transformer.
            b, c, f, t = x.shape
            x = rearrange(x, "b c f t-> b c (f t)")
            x = self.channel_upsampler(x)
            x = rearrange(x, "b c (f t)-> b c f t", f=f)
            xt = self.channel_upsampler_t(xt)

        x, xt = self.crosstransformer(x, xt)

        if self.bottom_channels:
            x = rearrange(x, "b c f t-> b c (f t)")
            x = self.channel_downsampler(x)
            x = rearrange(x, "b c (f t)-> b c f t", f=f)
            xt = self.channel_downsampler_t(xt)

    for idx, decode in enumerate(self.decoder):
        skip = saved.pop(-1)
        x, pre = decode(x, skip, lengths.pop(-1))
        # `pre` contains the output just before final transposed convolution,
        # which is used when the freq. and time branch separate.

        # The time branch has fewer layers: only start decoding it once
        # the matching depth has been reached.
        offset = self.depth - len(self.tdecoder)
        if idx >= offset:
            tdec = self.tdecoder[idx - offset]
            length_t = lengths_t.pop(-1)
            if tdec.empty:
                assert pre.shape[2] == 1, pre.shape
                pre = pre[:, :, 0]
                xt, _ = tdec(pre, None, length_t)
            else:
                skip = saved_t.pop(-1)
                xt, _ = tdec(xt, skip, length_t)

    # Let's make sure we used all stored skip connections.
    assert len(saved) == 0
    assert len(lengths_t) == 0
    assert len(saved_t) == 0

    S = len(self.sources)
    x = x.view(B, S, -1, Fq, T)
    # Undo the input normalization on the spectrogram estimate.
    x = x * std[:, None] + mean[:, None]

    # to cpu as mps doesn't support complex numbers
    # demucs issue #435 ##432
    # NOTE: in this case z already is on cpu
    # TODO: remove this when mps supports complex numbers
    x_is_mps_xpu = x.device.type in ["mps", "xpu"]
    x_device = x.device
    if x_is_mps_xpu:
        x = x.cpu()

    zout = self._mask(z, x)
    if self.use_train_segment:
        if self.training:
            x = self._ispec(zout, length)
        else:
            x = self._ispec(zout, training_length)
    else:
        x = self._ispec(zout, length)

    # back to mps device
    if x_is_mps_xpu:
        x = x.to(x_device)

    if self.use_train_segment:
        if self.training:
            xt = xt.view(B, S, -1, length)
        else:
            xt = xt.view(B, S, -1, training_length)
    else:
        xt = xt.view(B, S, -1, length)
    # Undo the time-branch normalization and merge both branches.
    xt = xt * stdt[:, None] + meant[:, None]
    x = xt + x
    if length_pre_pad:
        x = x[..., :length_pre_pad]
    return x
diff --git a/demucs/pretrained.py b/demucs/pretrained.py
new file mode 100644
index 00000000..80ae49cb
--- /dev/null
+++ b/demucs/pretrained.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Loading pretrained models.
+"""
+
+import logging
+from pathlib import Path
+import typing as tp
+
+from dora.log import fatal, bold
+
+from .hdemucs import HDemucs
+from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
+from .states import _check_diffq
+
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Base URL under which all remote pre-trained model files are hosted.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/"
# Directory bundled with the package holding the remote file manifest
# (`files.txt`) and the bag-of-models YAML definitions.
REMOTE_ROOT = Path(__file__).parent / 'remote'

# Names of the separated stems.
SOURCES = ["drums", "bass", "other", "vocals"]
# Model name used when the caller does not specify one.
DEFAULT_MODEL = 'htdemucs'
+
+
def demucs_unittest():
    """Build the tiny HDemucs instance used by the unit tests."""
    return HDemucs(channels=4, sources=SOURCES)
+
+
def add_model_flags(parser):
    """Register the model-selection flags on `parser`.

    `-s/--sig` and `-n/--name` are mutually exclusive ways of picking a
    model; `--repo` points `-n` at a local folder of pre-trained models.
    """
    selection = parser.add_mutually_exclusive_group(required=False)
    selection.add_argument("-s", "--sig", help="Locally trained XP signature.")
    selection.add_argument(
        "-n", "--name",
        default="htdemucs",
        help="Pretrained model name or signature. Default is htdemucs.")
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.")
+
+
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
    """Parse the remote `files.txt` manifest into a signature -> URL map.

    Blank lines and `#` comments are skipped; a `root:` line sets the path
    prefix (relative to ROOT_URL) applied to all subsequent entries; any
    other line is a model file name whose signature is the part before the
    first dash.
    """
    models: tp.Dict[str, str] = {}
    root: str = ''
    for raw in remote_file_list.read_text().split('\n'):
        line = raw.strip()
        if not line or line.startswith('#'):
            continue
        if line.startswith('root:'):
            root = line.split(':', 1)[1].strip()
            continue
        sig = line.split('-', 1)[0]
        assert sig not in models
        models[sig] = ROOT_URL + root + line
    return models
+
+
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    The returned model is put in eval mode.  Lookup failures surface as
    `ModelLoadingError` raised by the underlying repos.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        # No local repo given: use the manifest of remote files bundled
        # with the package, plus the bundled bag definitions.
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    try:
        model = any_repo.get_model(name)
    except ImportError as exc:
        # NOTE(review): presumably quantized models need the optional
        # `diffq` package and `_check_diffq` reports how to install it —
        # confirm against demucs.states.
        if 'diffq' in exc.args[0]:
            _check_diffq()
        raise

    model.eval()
    return model
+
+
def get_model_from_args(args):
    """
    Load local model package or pre-trained model.

    `args` comes from a parser prepared with `add_model_flags`; when no
    name was given, `DEFAULT_MODEL` is used.
    """
    if args.name is None:
        args.name = DEFAULT_MODEL
    # NOTE(review): this warning is printed on every call, not only when
    # the default is used — confirm that is intentional.
    print(bold("Important: the default model was recently changed to `htdemucs`"),
          "the latest Hybrid Transformer Demucs model. In some cases, this model can "
          "actually perform worse than previous models. To get back the old default model "
          "use `-n mdx_extra_q`.")
    return get_model(name=args.name, repo=args.repo)
diff --git a/demucs/py.typed b/demucs/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/demucs/repitch.py b/demucs/repitch.py
new file mode 100644
index 00000000..b69c0d25
--- /dev/null
+++ b/demucs/repitch.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utility for on the fly pitch/tempo change for data augmentation."""
+
+import random
+import subprocess as sp
+import tempfile
+
+from . import audio_legacy
+import torch
+import torchaudio as ta
+
+from .audio import save_audio
+
+
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    With probability `proba`, each item is passed through `repitch` with a
    random pitch shift (in semitones, up to `max_pitch`) and a random tempo
    delta (in percent, Gaussian with std `tempo_std`, clipped to
    `max_tempo`).  Stream indices listed in `vocals` use the
    speech-optimized mode.  Every item — transformed or not — is cropped to
    the worst-case length after the maximum speed-up so shapes stay
    consistent.
    """
    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=None, same=True):
        # `vocals` defaults to [3]; a None sentinel is used instead of a
        # mutable default argument (previously `vocals=[3]`).
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        self.same = same
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Worst-case output length after the maximum tempo speed-up.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            outs = []
            for idx, stream in enumerate(streams):
                if idx == 0 or not self.same:
                    # One transform per item (`same=True`) or one per
                    # stream (`same=False`).
                    delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
                    delta_tempo = random.gauss(0, self.tempo_std)
                    delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
+
+
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: input waveform tensor (channels, time).
        pitch: pitch shift in semitones.
        tempo: relative tempo delta in percent.
        voice: use soundstretch's speech-optimized processing.
        quick: use soundstretch's faster, lower-quality algorithm.
        samplerate: sample rate of `wav`.

    Raises:
        RuntimeError: if soundstretch fails (original error chained).
    """
    # The temp files are deleted when these objects are garbage collected.
    infile = tempfile.NamedTemporaryFile(suffix=".wav")
    outfile = tempfile.NamedTemporaryFile(suffix=".wav")
    save_audio(wav, infile.name, samplerate, clip='clamp')
    command = [
        "soundstretch",
        infile.name,
        outfile.name,
        f"-pitch={pitch}",
        f"-tempo={tempo:.6f}",
    ]
    if quick:
        command += ["-quick"]
    if voice:
        command += ["-speech"]
    try:
        sp.run(command, capture_output=True, check=True)
    except sp.CalledProcessError as error:
        # Chain the original error so the return code and full stderr of
        # the CalledProcessError are not lost.
        raise RuntimeError(
            f"Could not change bpm because {error.stderr.decode('utf-8')}") from error
    wav, sr = ta.load(outfile.name)
    assert sr == samplerate
    return wav
diff --git a/demucs/repo.py b/demucs/repo.py
new file mode 100644
index 00000000..5e20ff51
--- /dev/null
+++ b/demucs/repo.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Represents a model repository, including pre-trained models and bags of models.
+A repo can either be the main remote repository stored in AWS, or a local repository
+with your own models.
+"""
+
+from hashlib import sha256
+from pathlib import Path
+import typing as tp
+
+import torch
+import yaml
+
+from .apply import BagOfModels, Model
+from .states import load_model
+
+
# Either a single loaded model or a bag of models, as served by the repos below.
AnyModel = tp.Union[Model, BagOfModels]


class ModelLoadingError(RuntimeError):
    """Raised when a requested model or bag cannot be found, fails its
    checksum, or otherwise cannot be loaded."""
    pass
+
+
def check_checksum(path: Path, checksum: str):
    """Verify that the sha256 of the file at `path` starts with `checksum`.

    Only the prefix of the hex digest with the same length as `checksum` is
    compared.  Raises `ModelLoadingError` on mismatch.
    """
    digest = sha256()
    with open(path, 'rb') as stream:
        # Hash in 1 MiB chunks to keep memory bounded on large files.
        for chunk in iter(lambda: stream.read(2 ** 20), b''):
            digest.update(chunk)
    actual_checksum = digest.hexdigest()[:len(checksum)]
    if actual_checksum != checksum:
        raise ModelLoadingError(f'Invalid checksum for file {path}, '
                                f'expected {checksum} but got {actual_checksum}')
+
+
class ModelOnlyRepo:
    """Base class for all model only repos.
    """
    def has_model(self, sig: str) -> bool:
        """Return whether a model with signature `sig` is available."""
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        """Load and return the model with signature `sig`."""
        raise NotImplementedError()

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        """Return a mapping from signature to model location (path or URL)."""
        raise NotImplementedError()
+
+
class RemoteRepo(ModelOnlyRepo):
    """Serves single pre-trained models downloaded from remote URLs."""

    def __init__(self, models: tp.Dict[str, str]):
        # Mapping from model signature to its download URL.
        self._urls = models

    def has_model(self, sig: str) -> bool:
        return sig in self._urls

    def get_model(self, sig: str) -> Model:
        try:
            url = self._urls[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
        package = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        return load_model(package)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._urls  # type: ignore
+
+
class LocalRepo(ModelOnlyRepo):
    """Serves single pre-trained models stored as `.th` files in a folder.

    File names are either `<sig>.th` or `<sig>-<checksum>.th`; in the
    latter case the checksum is verified before loading.
    """

    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        """(Re)build the signature -> file index from `self.root`."""
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix != '.th':
                continue
            if '-' in file.stem:
                xp_sig, checksum = file.stem.split('-')
                self._checksums[xp_sig] = checksum
            else:
                xp_sig = file.stem
            if xp_sig in self._models:
                raise ModelLoadingError(
                    f'Duplicate pre-trained model exist for signature {xp_sig}. '
                    'Please delete all but one.')
            self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._models
+
+
class BagOnlyRepo:
    """Handles only YAML files containing bag of models, leaving the actual
    model loading to some Repo.
    """
    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        """Index every `<name>.yaml` bag definition directly under `root`."""
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == '.yaml':
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        """Load every model listed in the bag and wrap them in BagOfModels."""
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
                                    'a bag of models.')
        # Fix: read via Path.read_text() instead of a bare `open()` whose
        # file handle was never closed.
        bag = yaml.safe_load(yaml_file.read_text())
        signatures = bag['models']
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get('weights')
        segment = bag.get('segment')
        return BagOfModels(models, weights, segment)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._bags
+
+
class AnyModelRepo:
    """Dispatches between a single-model repo and a bag-of-models repo,
    with single models taking precedence on name clashes."""

    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        else:
            return self.bag_repo.get_model(name_or_sig)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        # Fix: copy before merging. The sub-repos may hand out their
        # internal dict (e.g. LocalRepo returns `self._models` directly),
        # and the previous in-place update polluted it with bag entries.
        models = dict(self.model_repo.list_model())
        models.update(self.bag_repo.list_model())
        return models
diff --git a/demucs/separate.py b/demucs/separate.py
new file mode 100644
index 00000000..7de5f114
--- /dev/null
+++ b/demucs/separate.py
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import sys
+from pathlib import Path
+
+from dora.log import fatal
+import torch as th
+
+from .api import Separator, save_audio, list_models
+
+from .apply import BagOfModels
+from .htdemucs import HTDemucs
+from .pretrained import add_model_flags, ModelLoadingError
+
+
def get_parser():
    """Build the argument parser for the `demucs.separate` command line."""
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks')
    # Adds -s/--sig, -n/--name and --repo (model selection flags).
    add_model_flags(parser)
    parser.add_argument("--list-models", action="store_true", help="List available models "
                        "from current repo and exit")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                        "with the model name will be created.")
    parser.add_argument("--filename",
                        default="{track}/{stem}.{ext}",
                        help="Set the name of output file. \n"
                        'Use "{track}", "{trackext}", "{stem}", "{ext}" to use '
                        "variables of track name without extension, track extension, "
                        "stem name and default output file extension. \n"
                        'Default is "{track}/{stem}.{ext}".')
    parser.add_argument("-d",
                        "--device",
                        default=(
                            # Prefer CUDA, then Apple MPS, then CPU.
                            "cuda"
                            if th.cuda.is_available()
                            else "mps"
                            if th.backends.mps.is_available()
                            else "cpu"
                        ),
                        help="Device to use, default is cuda if available else cpu")
    parser.add_argument("--shifts",
                        default=1,
                        type=int,
                        help="Number of random shifts for equivariant stabilization."
                        "Increase separation time but improves quality for Demucs. 10 was used "
                        "in the original paper.")
    parser.add_argument("--overlap",
                        default=0.25,
                        type=float,
                        help="Overlap between the splits.")
    # --no-split and --segment are mutually exclusive chunking options.
    split_group = parser.add_mutually_exclusive_group()
    split_group.add_argument("--no-split",
                             action="store_false",
                             dest="split",
                             default=True,
                             help="Doesn't split audio in chunks. "
                             "This can use large amounts of memory.")
    split_group.add_argument("--segment", type=int,
                             help="Set split size of each chunk. "
                             "This can help save memory of graphic card. ")
    parser.add_argument("--two-stems",
                        dest="stem", metavar="STEM",
                        help="Only separate audio into {STEM} and no_{STEM}. ")
    parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"],
                        default="add", help='Decide how to get "no_{STEM}". "none" will not save '
                        '"no_{STEM}". "add" will add all the other stems. "minus" will use the '
                        "original track minus the selected stem.")
    # Output wav sample depth: 24-bit int or float32.
    depth_group = parser.add_mutually_exclusive_group()
    depth_group.add_argument("--int24", action="store_true",
                             help="Save wav output as 24 bits wav.")
    depth_group.add_argument("--float32", action="store_true",
                             help="Save wav output as float32 (2x bigger).")
    parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"],
                        help="Strategy for avoiding clipping: rescaling entire signal "
                        "if necessary (rescale) or hard clipping (clamp).")
    # Output codec: wav (default), flac, or mp3.
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument("--flac", action="store_true",
                              help="Convert the output wavs to flac.")
    format_group.add_argument("--mp3", action="store_true",
                              help="Convert the output wavs to mp3.")
    parser.add_argument("--mp3-bitrate",
                        default=320,
                        type=int,
                        help="Bitrate of converted mp3.")
    parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2,
                        help="Encoder preset of MP3, 2 for highest quality, 7 for "
                        "fastest speed. Default is 2")
    parser.add_argument("-j", "--jobs",
                        default=0,
                        type=int,
                        help="Number of jobs. This can increase memory usage but will "
                        "be much faster when multiple cores are available.")

    return parser
+
+
def main(opts=None):
    """Command-line entry point: separate every given track into stems.

    `opts` defaults to `sys.argv[1:]` (argparse behavior for None).
    """
    parser = get_parser()
    args = parser.parse_args(opts)
    if args.list_models:
        # --list-models: print available bags and single models, then exit.
        models = list_models(args.repo)
        print("Bag of models:", end="\n    ")
        print("\n    ".join(models["bag"]))
        print("Single models:", end="\n    ")
        print("\n    ".join(models["single"]))
        sys.exit(0)
    if len(args.tracks) == 0:
        print("error: the following arguments are required: tracks", file=sys.stderr)
        sys.exit(1)

    try:
        separator = Separator(model=args.name,
                              repo=args.repo,
                              device=args.device,
                              shifts=args.shifts,
                              split=args.split,
                              overlap=args.overlap,
                              progress=True,
                              jobs=args.jobs,
                              segment=args.segment)
    except ModelLoadingError as error:
        fatal(error.args[0])

    # Transformer models cannot run on segments longer than they were
    # trained on; validate the user-provided --segment against that limit.
    max_allowed_segment = float('inf')
    if isinstance(separator.model, HTDemucs):
        max_allowed_segment = float(separator.model.segment)
    elif isinstance(separator.model, BagOfModels):
        max_allowed_segment = separator.model.max_allowed_segment
    if args.segment is not None and args.segment > max_allowed_segment:
        fatal("Cannot use a Transformer model with a longer segment "
              f"than it was trained for. Maximum segment is: {max_allowed_segment}")

    if isinstance(separator.model, BagOfModels):
        print(
            f"Selected model is a bag of {len(separator.model.models)} models. "
            "You will see that many progress bars per track."
        )

    if args.stem is not None and args.stem not in separator.model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. '
            "STEM must be one of {sources}.".format(
                stem=args.stem, sources=", ".join(separator.model.sources)
            )
        )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")
    for track in args.tracks:
        if not track.exists():
            print(f"File {track} does not exist. If the path contains spaces, "
                  'please try again after surrounding the entire path with quotes "".',
                  file=sys.stderr)
            continue
        print(f"Separating track {track}")

        origin, res = separator.separate_audio_file(track)

        if args.mp3:
            ext = "mp3"
        elif args.flac:
            ext = "flac"
        else:
            ext = "wav"
        # Common options forwarded to `save_audio` for every written stem.
        kwargs = {
            "samplerate": separator.samplerate,
            "bitrate": args.mp3_bitrate,
            "preset": args.mp3_preset,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
        if args.stem is None:
            # Default mode: write every separated stem.
            for name, source in res.items():
                stem = out / args.filename.format(
                    track=track.name.rsplit(".", 1)[0],
                    trackext=track.name.rsplit(".", 1)[-1],
                    stem=name,
                    ext=ext,
                )
                stem.parent.mkdir(parents=True, exist_ok=True)
                save_audio(source, str(stem), **kwargs)
        else:
            # --two-stems mode: write {stem} plus, depending on
            # --other-method, a complementary minus_{stem} / no_{stem} file.
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem="minus_" + args.stem,
                ext=ext,
            )
            if args.other_method == "minus":
                stem.parent.mkdir(parents=True, exist_ok=True)
                save_audio(origin - res[args.stem], str(stem), **kwargs)
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem=args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(res.pop(args.stem), str(stem), **kwargs)
            # Warning : after popping the stem, selected stem is no longer in the dict 'res'
            if args.other_method == "add":
                other_stem = th.zeros_like(next(iter(res.values())))
                for i in res.values():
                    other_stem += i
                stem = out / args.filename.format(
                    track=track.name.rsplit(".", 1)[0],
                    trackext=track.name.rsplit(".", 1)[-1],
                    stem="no_" + args.stem,
                    ext=ext,
                )
                stem.parent.mkdir(parents=True, exist_ok=True)
                save_audio(other_stem, str(stem), **kwargs)


if __name__ == "__main__":
    main()
diff --git a/demucs/solver.py b/demucs/solver.py
new file mode 100644
index 00000000..7c80b148
--- /dev/null
+++ b/demucs/solver.py
@@ -0,0 +1,405 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Main training loop."""
+
+import logging
+
+from dora import get_xp
+from dora.utils import write_and_rename
+from dora.log import LogProgress, bold
+import torch
+import torch.nn.functional as F
+
+from . import augment, distrib, states, pretrained
+from .apply import apply_model
+from .ema import ModelEMA
+from .evaluate import evaluate, new_sdr
+from .svd import svd_penalty
+from .utils import pull_metric, EMA
+
+logger = logging.getLogger(__name__)
+
+
+def _summary(metrics):
+ return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items())
+
+
class Solver(object):
    """Full training loop for a source separation model.

    Handles epoch iteration, EMA model tracking, validation-based model
    selection, periodic test-set evaluation and checkpointing. The loop is
    resumable: history is replayed from the Dora XP link and weights are
    restored from `checkpoint.th` in `_reset`.
    """

    def __init__(self, loaders, model, optimizer, args):
        self.args = args
        self.loaders = loaders

        self.model = model
        self.optimizer = optimizer
        self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
        self.dmodel = distrib.wrap(model)
        self.device = next(iter(self.model.parameters())).device

        # Exponential moving average of the model, either updated every batch or epoch.
        # The best model from all the EMAs and the original one is kept based on the valid
        # loss for the final best model.
        self.emas = {'batch': [], 'epoch': []}
        for kind in self.emas.keys():
            decays = getattr(args.ema, kind)
            # Batch-level EMAs live on the training device; epoch-level ones on CPU.
            device = self.device if kind == 'batch' else 'cpu'
            if decays:
                for decay in decays:
                    self.emas[kind].append(ModelEMA(self.model, decay, device=device))

        # data augment
        augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift),
                                  same=args.augment.shift_same)]
        if args.augment.flip:
            augments += [augment.FlipChannels(), augment.FlipSign()]
        for aug in ['scale', 'remix']:
            kw = getattr(args.augment, aug)
            if kw.proba:
                augments.append(getattr(augment, aug.capitalize())(**kw))
        self.augment = torch.nn.Sequential(*augments)

        xp = get_xp()
        self.folder = xp.folder
        # Checkpoints
        self.checkpoint_file = xp.folder / 'checkpoint.th'
        self.best_file = xp.folder / 'best.th'
        logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
        self.best_state = None
        self.best_changed = False

        self.link = xp.link
        self.history = self.link.history

        self._reset()

    def _serialize(self, epoch):
        """Write the training checkpoint, optional periodic snapshots, and
        (when the best model changed) the exported `best.th` model file."""
        package = {}
        package['state'] = self.model.state_dict()
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        for kind, emas in self.emas.items():
            for k, ema in enumerate(emas):
                package[f'ema_{kind}_{k}'] = ema.state_dict()
        # write_and_rename writes to a temp file then renames, so a crash
        # mid-save cannot corrupt the existing checkpoint.
        with write_and_rename(self.checkpoint_file) as tmp:
            torch.save(package, tmp)

        save_every = self.args.save_every
        if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs:
            with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp:
                torch.save(package, tmp)

        if self.best_changed:
            # Saving only the latest best model.
            with write_and_rename(self.best_file) as tmp:
                package = states.serialize_model(self.model, self.args)
                package['state'] = self.best_state
                torch.save(package, tmp)
            self.best_changed = False

    def _reset(self):
        """Reset state of the solver, potentially using checkpoint."""
        if self.checkpoint_file.exists():
            # Resume this exact run: model, optimizer, history and EMAs.
            logger.info(f'Loading checkpoint model: {self.checkpoint_file}')
            package = torch.load(self.checkpoint_file, 'cpu')
            self.model.load_state_dict(package['state'])
            self.optimizer.load_state_dict(package['optimizer'])
            self.history[:] = package['history']
            self.best_state = package['best_state']
            for kind, emas in self.emas.items():
                for k, ema in enumerate(emas):
                    ema.load_state_dict(package[f'ema_{kind}_{k}'])
        elif self.args.continue_pretrained:
            # Warm-start from a released pretrained model (weights only).
            model = pretrained.get_model(
                name=self.args.continue_pretrained,
                repo=self.args.pretrained_repo)
            self.model.load_state_dict(model.state_dict())
        elif self.args.continue_from:
            # Warm-start from a sibling XP's checkpoint; strict=False so the
            # architectures may differ slightly.
            name = 'checkpoint.th'
            root = self.folder.parent
            cf = root / str(self.args.continue_from) / name
            logger.info("Loading from %s", cf)
            package = torch.load(cf, 'cpu')
            self.best_state = package['best_state']
            if self.args.continue_best:
                self.model.load_state_dict(package['best_state'], strict=False)
            else:
                self.model.load_state_dict(package['state'], strict=False)
            if self.args.continue_opt:
                self.optimizer.load_state_dict(package['optimizer'])

    def _format_train(self, metrics: dict) -> dict:
        """Formatting for train/valid metrics."""
        losses = {
            'loss': format(metrics['loss'], ".4f"),
            'reco': format(metrics['reco'], ".4f"),
        }
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], ".3f")
        if self.quantizer is not None:
            losses['ms'] = format(metrics['ms'], ".2f")
        if 'grad' in metrics:
            losses['grad'] = format(metrics['grad'], ".4f")
        if 'best' in metrics:
            losses['best'] = format(metrics['best'], '.4f')
        if 'bname' in metrics:
            losses['bname'] = metrics['bname']
        if 'penalty' in metrics:
            losses['penalty'] = format(metrics['penalty'], ".4f")
        if 'hloss' in metrics:
            losses['hloss'] = format(metrics['hloss'], ".4f")
        return losses

    def _format_test(self, metrics: dict) -> dict:
        """Formatting for test metrics."""
        losses = {}
        if 'sdr' in metrics:
            losses['sdr'] = format(metrics['sdr'], '.3f')
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], '.3f')
        for source in self.model.sources:
            key = f'sdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
            key = f'nsdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
        return losses

    def train(self):
        """Run the full training loop, resuming from history if present."""
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
            for epoch, metrics in enumerate(self.history):
                formatted = self._format_train(metrics['train'])
                logger.info(
                    bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
                formatted = self._format_train(metrics['valid'])
                logger.info(
                    bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
                if 'test' in metrics:
                    formatted = self._format_test(metrics['test'])
                    if formatted:
                        logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))

        epoch = 0
        for epoch in range(len(self.history), self.args.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            metrics = {}
            logger.info('-' * 70)
            logger.info("Training...")
            metrics['train'] = self._run_one_epoch(epoch)
            formatted = self._format_train(metrics['train'])
            logger.info(
                bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off Batchnorm & Dropout
            with torch.no_grad():
                valid = self._run_one_epoch(epoch, train=False)
                bvalid = valid
                bname = 'main'
                state = states.copy_state(self.model.state_dict())
                metrics['valid'] = {}
                metrics['valid']['main'] = valid
                key = self.args.test.metric
                # Validate every EMA variant and keep the best one by `key`.
                for kind, emas in self.emas.items():
                    for k, ema in enumerate(emas):
                        with ema.swap():
                            valid = self._run_one_epoch(epoch, train=False)
                        name = f'ema_{kind}_{k}'
                        metrics['valid'][name] = valid
                        a = valid[key]
                        b = bvalid[key]
                        # nsdr metrics are higher-is-better: flip sign to compare.
                        if key.startswith('nsdr'):
                            a = -a
                            b = -b
                        if a < b:
                            bvalid = valid
                            state = ema.state
                            bname = name
                metrics['valid'].update(bvalid)
                metrics['valid']['bname'] = bname

            valid_loss = metrics['valid'][key]
            mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
            if key.startswith('nsdr'):
                best_loss = max(mets)
            else:
                best_loss = min(mets)
            metrics['valid']['best'] = best_loss
            if self.args.svd.penalty > 0:
                kw = dict(self.args.svd)
                kw.pop('penalty')
                with torch.no_grad():
                    penalty = svd_penalty(self.model, exact=True, **kw)
                metrics['valid']['penalty'] = penalty

            formatted = self._format_train(metrics['valid'])
            logger.info(
                bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Save the best model
            if valid_loss == best_loss or self.args.dset.train_valid:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = states.copy_state(state)
                self.best_changed = True

            # Eval model every `test.every` epoch or on last epoch
            should_eval = (epoch + 1) % self.args.test.every == 0
            is_last = epoch == self.args.epochs - 1
            # # Tries to detect divergence in a reliable way and finish job
            # # not to waste compute.
            # # Commented out as this was super specific to the MDX competition.
            # reco = metrics['valid']['main']['reco']
            # div = epoch >= 180 and reco > 0.18
            # div = div or epoch >= 100 and reco > 0.25
            # div = div and self.args.optim.loss == 'l1'
            # if div:
            #     logger.warning("Finishing training early because valid loss is too high.")
            #     is_last = True
            if should_eval or is_last:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                if self.args.test.best:
                    state = self.best_state
                else:
                    state = states.copy_state(self.model.state_dict())
                compute_sdr = self.args.test.sdr and is_last
                with states.swap_state(self.model, state):
                    with torch.no_grad():
                        metrics['test'] = evaluate(self, compute_sdr=compute_sdr)
                formatted = self._format_test(metrics['test'])
                logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
            self.link.push_metrics(metrics)

            if distrib.rank == 0:
                # Save model each epoch
                self._serialize(epoch)
                logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
            if is_last:
                break

    def _run_one_epoch(self, epoch, train=True):
        """Run one pass over the train (or valid) loader; return averaged losses.

        NOTE(review): batches appear to be laid out as (batch, streams,
        channels, time) — training sums streams over dim 1 to build the mix,
        while validation expects the mix pre-stacked at stream index 0 —
        confirm against the dataset implementation.
        """
        args = self.args
        data_loader = self.loaders['train'] if train else self.loaders['valid']
        if distrib.world_size > 1 and train:
            data_loader.sampler.set_epoch(epoch)

        label = ["Valid", "Train"][train]
        name = label + f" | Epoch {epoch + 1}"
        total = len(data_loader)
        if args.max_batches:
            total = min(total, args.max_batches)
        logprog = LogProgress(logger, data_loader, total=total,
                              updates=self.args.misc.num_prints, name=name)
        averager = EMA()

        for idx, sources in enumerate(logprog):
            sources = sources.to(self.device)
            if train:
                sources = self.augment(sources)
                mix = sources.sum(dim=1)
            else:
                mix = sources[:, 0]
                sources = sources[:, 1:]

            if not train and self.args.valid_apply:
                estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
            else:
                estimate = self.dmodel(mix)
            if train and hasattr(self.model, 'transform_target'):
                sources = self.model.transform_target(mix, sources)
            assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
            dims = tuple(range(2, sources.dim()))

            if args.optim.loss == 'l1':
                loss = F.l1_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims).mean(0)
                reco = loss
            elif args.optim.loss == 'mse':
                loss = F.mse_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims)
                reco = loss**0.5
                reco = reco.mean(0)
            else:
                raise ValueError(f"Invalid loss {self.args.loss}")
            # Per-source weighting of the loss.
            weights = torch.tensor(args.weights).to(sources)
            loss = (loss * weights).sum() / weights.sum()

            ms = 0
            if self.quantizer is not None:
                ms = self.quantizer.model_size()
            if args.quant.diffq:
                loss += args.quant.diffq * ms

            losses = {}
            losses['reco'] = (reco * weights).sum() / weights.sum()
            losses['ms'] = ms

            if not train:
                nsdrs = new_sdr(sources, estimate.detach()).mean(0)
                total = 0
                for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
                    losses[f'nsdr_{source}'] = nsdr
                    total += w * nsdr
                losses['nsdr'] = total / weights.sum()

            if train and args.svd.penalty > 0:
                kw = dict(args.svd)
                kw.pop('penalty')
                penalty = svd_penalty(self.model, **kw)
                losses['penalty'] = penalty
                loss += args.svd.penalty * penalty

            losses['loss'] = loss

            for k, source in enumerate(self.model.sources):
                losses[f'reco_{source}'] = reco[k]

            # optimize model in training mode
            if train:
                loss.backward()
                grad_norm = 0
                grads = []
                for p in self.model.parameters():
                    if p.grad is not None:
                        grad_norm += p.grad.data.norm()**2
                        grads.append(p.grad.data)
                losses['grad'] = grad_norm ** 0.5
                if args.optim.clip_grad:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        args.optim.clip_grad)

                if self.args.flag == 'uns':
                    for n, p in self.model.named_parameters():
                        if p.grad is None:
                            print('no grad', n)
                self.optimizer.step()
                self.optimizer.zero_grad()
                for ema in self.emas['batch']:
                    ema.update()
            losses = averager(losses)
            logs = self._format_train(losses)
            logprog.update(**logs)
            # Just in case, clear some memory
            del loss, estimate, reco, ms
            if args.max_batches == idx:
                break
            if self.args.debug and train:
                break
            if self.args.flag == 'debug':
                break
        if train:
            for ema in self.emas['epoch']:
                ema.update()
        return distrib.average(losses, idx + 1)
diff --git a/demucs/spec.py b/demucs/spec.py
new file mode 100644
index 00000000..d8f6ee5e
--- /dev/null
+++ b/demucs/spec.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Conveniance wrapper to perform STFT and iSTFT"""
+
+import torch as th
+
+
def spectro(x, n_fft=512, hop_length=None, pad=0):
    """Complex STFT of `x` along its last (time) axis.

    Leading dimensions are preserved; the result has shape
    (*leading, n_fft // 2 + 1, frames).
    """
    *batch_dims, length = x.shape
    flat = x.reshape(-1, length)
    # torch.stft is not implemented for MPS/XPU backends, so fall back to CPU.
    if flat.device.type in ['mps', 'xpu']:
        flat = flat.cpu()
    hop = hop_length or n_fft // 4
    window = th.hann_window(n_fft).to(flat)
    z = th.stft(flat,
                n_fft * (1 + pad),
                hop,
                window=window,
                win_length=n_fft,
                normalized=True,
                center=True,
                return_complex=True,
                pad_mode='reflect')
    _, freqs, frames = z.shape
    return z.view(*batch_dims, freqs, frames)
+
+
def ispectro(z, hop_length=None, length=None, pad=0):
    """Inverse of `spectro`: rebuild a waveform from a complex spectrogram.

    Leading dimensions are preserved; `n_fft` is recovered from the number
    of frequency bins.
    """
    *batch_dims, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    win_length = n_fft // (1 + pad)
    flat = z.view(-1, freqs, frames)
    # torch.istft is not implemented for MPS/XPU backends, so fall back to CPU.
    if flat.device.type in ['mps', 'xpu']:
        flat = flat.cpu()
    window = th.hann_window(win_length).to(flat.real)
    x = th.istft(flat,
                 n_fft,
                 hop_length,
                 window=window,
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, out_length = x.shape
    return x.view(*batch_dims, out_length)
diff --git a/demucs/states.py b/demucs/states.py
new file mode 100644
index 00000000..361bb419
--- /dev/null
+++ b/demucs/states.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Utilities to save and load models.
+"""
+from contextlib import contextmanager
+
+import functools
+import hashlib
+import inspect
+import io
+from pathlib import Path
+import warnings
+
+from omegaconf import OmegaConf
+from dora.log import fatal
+import torch
+
+
def _check_diffq():
    """Abort with installation instructions if the optional `diffq`
    dependency (needed for quantization features) is not installed."""
    try:
        import diffq  # noqa
    except ImportError:
        fatal('Trying to use DiffQ, but diffq is not installed.\n'
              'On Windows run: python.exe -m pip install diffq \n'
              'On Linux/Mac, run: python3 -m pip install diffq')
+
+
def get_quantizer(model, args, optimizer=None):
    """Return the quantizer given the XP quantization args.

    Returns None when neither `args.diffq` nor `args.qat` is set.
    `diffq` is imported lazily (it is an optional dependency).
    """
    quantizer = None
    if args.diffq:
        _check_diffq()
        from diffq import DiffQuantizer
        quantizer = DiffQuantizer(
            model, min_size=args.min_size, group_size=args.group_size)
        if optimizer is not None:
            # DiffQ learns its quantization noise, so it must hook the optimizer.
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        _check_diffq()
        from diffq import UniformQuantizer
        quantizer = UniformQuantizer(
            model, bits=args.qat, min_size=args.min_size)
    return quantizer
+
+
def load_model(path_or_package, strict=False):
    """Load a model from the given serialized model, either given as a dict (already loaded)
    or a path to a file on disk."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        path = path_or_package
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            package = torch.load(path, 'cpu')
    else:
        raise ValueError(f"Invalid type for {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if not strict:
        # Drop constructor kwargs the current class no longer accepts, so older
        # checkpoints keep loading after the model signature evolves.
        accepted = inspect.signature(klass).parameters
        for key in list(kwargs):
            if key not in accepted:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
    model = klass(*args, **kwargs)

    set_state(model, package["state"])
    return model
+
+
def get_state(model, quantizer, half=False):
    """Extract a CPU state dict from `model`, quantized if a quantizer is given.

    With `half=True` (and no quantizer), weights are stored as float16,
    halving the checkpoint size with negligible quality impact.
    """
    if quantizer is not None:
        quantized = quantizer.get_quantized_state()
        # Marker consumed by `set_state` to pick the quantized restore path.
        quantized['__quantized'] = True
        return quantized
    target_dtype = torch.half if half else None
    out = {}
    for name, param in model.state_dict().items():
        out[name] = param.data.to(device='cpu', dtype=target_dtype)
    return out
+
+
def set_state(model, state, quantizer=None):
    """Set the state on a given model."""
    if not state.get('__quantized'):
        # Plain (non-quantized) checkpoint: a regular state dict.
        model.load_state_dict(state)
        return state
    if quantizer is not None:
        quantizer.restore_quantized_state(model, state['quantized'])
    else:
        _check_diffq()
        from diffq import restore_quantized_state
        restore_quantized_state(model, state)
    return state
+
+
def save_with_checksum(content, path):
    """Save the given value on disk, along with a sha256 hash.
    Should be used with the output of either `serialize_model` or `get_state`."""
    buffer = io.BytesIO()
    torch.save(content, buffer)
    payload = buffer.getvalue()
    # The first 8 hex chars of the sha256 digest go into the file name.
    digest = hashlib.sha256(payload).hexdigest()[:8]
    final = path.parent / f"{path.stem}-{digest}{path.suffix}"
    final.write_bytes(payload)
+
+
def serialize_model(model, training_args, quantizer=None, half=True):
    """Serialize a model (class, constructor arguments and weights) into a
    plain dict that `load_model` can later reconstruct.

    Requires the model's `__init__` to be decorated with `capture_init`.
    """
    init_args, init_kwargs = model._init_args_kwargs
    package = {
        'klass': model.__class__,
        'args': init_args,
        'kwargs': init_kwargs,
        'state': get_state(model, quantizer, half),
        # Resolve the OmegaConf config to plain containers so the checkpoint
        # does not depend on omegaconf at load time.
        'training_args': OmegaConf.to_container(training_args, resolve=True),
    }
    return package
+
+
def copy_state(state):
    """Return a detached CPU copy of a state dict (safe to keep around)."""
    copied = {}
    for name, tensor in state.items():
        copied[name] = tensor.cpu().clone()
    return copied
+
+
@contextmanager
def swap_state(model, state):
    """
    Context manager that swaps the state of a model, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state
    """
    saved = copy_state(model.state_dict())
    # strict=False: the temporary state may cover only part of the model.
    model.load_state_dict(state, strict=False)
    try:
        yield
    finally:
        # Restore even if the body raised.
        model.load_state_dict(saved)
+
+
def capture_init(init):
    """Decorator for `__init__` that records the constructor arguments on the
    instance as `_init_args_kwargs`, so `serialize_model` can rebuild it."""
    @functools.wraps(init)
    def wrapped(self, *args, **kwargs):
        # Remember exactly how the object was constructed before delegating.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapped
diff --git a/demucs/svd.py b/demucs/svd.py
new file mode 100644
index 00000000..1cbaa82c
--- /dev/null
+++ b/demucs/svd.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Ways to make the model stronger."""
+import random
+import torch
+
+
def power_iteration(m, niters=1, bs=1):
    """This is the power method. batch size is used to try multiple starting point in parallel."""
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    size = m.shape[0]
    vecs = torch.randn(size, bs, device=m.device, dtype=m.dtype)

    for _ in range(niters):
        prod = m.mm(vecs)
        norm = prod.norm(dim=0, keepdim=True)
        # Small epsilon keeps the normalization stable near zero vectors.
        vecs = prod / (1e-10 + norm)

    return norm.mean()
+
+
+# We need a shared RNG to make sure all the distributed worker will skip the penalty together,
+# as otherwise we wouldn't get any speed up.
+penalty_rng = random.Random(1234)
+
+
def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True,
                proba=1, conv_only=False, exact=False, bs=1):
    """
    Penalty on the largest singular value for a layer.
    Args:
        - model: model to penalize
        - min_size: minimum size in MB of a layer to penalize.
        - dim: projection dimension for the svd_lowrank. Higher is better but slower.
        - niters: number of iterations in the algorithm used by svd_lowrank.
        - powm: use power method instead of lowrank SVD, my own experience
          is that it is both slower and less stable.
        - convtr: when True, differentiate between Conv and Transposed Conv.
          this is kept for compatibility with older experiments.
        - proba: probability to apply the penalty.
        - conv_only: only apply to conv and conv transposed, not LSTM
          (might not be reliable for other models than Demucs).
        - exact: use exact SVD (slow but useful at validation).
        - bs: batch_size for power method.
    """
    total = 0
    # Shared RNG (module-level `penalty_rng`): all distributed workers draw
    # the same value and skip the penalty together.
    if penalty_rng.random() > proba:
        return 0.

    for m in model.modules():
        for name, p in m.named_parameters(recurse=False):
            # Skip small tensors; threshold is in MB assuming float32 (2**18 floats/MB).
            if p.numel() / 2**18 < min_size:
                continue
            if convtr:
                if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)):
                    if p.dim() in [3, 4]:
                        # Transposed conv weights have in/out channels swapped.
                        p = p.transpose(0, 1).contiguous()
                if p.dim() == 3:
                    p = p.view(len(p), -1)
                elif p.dim() == 4:
                    p = p.view(len(p), -1)
                elif p.dim() == 1:
                    continue
            elif conv_only:
                continue
            assert p.dim() == 2, (name, p.shape)
            if exact:
                # Largest squared singular value via full SVD.
                estimate = torch.svd(p, compute_uv=False)[1].pow(2).max()
            elif powm:
                a, b = p.shape
                # Work on the smaller Gram matrix for speed.
                if a < b:
                    n = p.mm(p.t())
                else:
                    n = p.t().mm(p)
                estimate = power_iteration(n, niters, bs)
            else:
                estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2)
            total += estimate
    # Divide by proba so the penalty is unbiased despite random skipping.
    return total / proba
diff --git a/demucs/train.py b/demucs/train.py
new file mode 100644
index 00000000..e045b83f
--- /dev/null
+++ b/demucs/train.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Main training script entry point"""
+
+import logging
+import os
+from pathlib import Path
+import sys
+
+from dora import hydra_main
+import hydra
+from hydra.core.global_hydra import GlobalHydra
+from omegaconf import OmegaConf
+from . import audio_legacy
+import torch
+from torch import nn
+import torchaudio
+from torch.utils.data import ConcatDataset
+
+from . import distrib
+from .wav import get_wav_datasets, get_musdb_wav_datasets
+from .demucs import Demucs
+from .hdemucs import HDemucs
+from .htdemucs import HTDemucs
+from .repitch import RepitchedWrapper
+from .solver import Solver
+from .states import capture_init
+from .utils import random_subset
+
+logger = logging.getLogger(__name__)
+
+
class TorchHDemucsWrapper(nn.Module):
    """Wrapper around torchaudio HDemucs implementation to provide the proper metadata
    for model evaluation.
    See https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html"""

    @capture_init
    def __init__(self, **kwargs):
        super().__init__()
        try:
            from torchaudio.models import HDemucs as TorchHDemucs
        except ImportError:
            raise ImportError("Please upgrade torchaudio for using its implementation of HDemucs")
        # samplerate/segment/sources are kept on the wrapper so the evaluation
        # code can read them; the remaining kwargs go to the torchaudio model.
        self.samplerate = kwargs.pop('samplerate')
        self.segment = kwargs.pop('segment')
        self.sources = kwargs['sources']
        self.torch_hdemucs = TorchHDemucs(**kwargs)

    def forward(self, mix):
        # Delegate straight to the underlying torchaudio implementation.
        return self.torch_hdemucs.forward(mix)
+
+
def get_model(args):
    """Instantiate the model selected by `args.model`, merging dataset-level
    metadata (sources, channels, samplerate, segment) with the model-specific
    hyper-parameters from the config section of the same name."""
    extra = {
        'sources': list(args.dset.sources),
        'audio_channels': args.dset.channels,
        'samplerate': args.dset.samplerate,
        # `model_segment` overrides the default of 4x the training segment.
        'segment': args.model_segment or 4 * args.dset.segment,
    }
    klass = {
        'demucs': Demucs,
        'hdemucs': HDemucs,
        'htdemucs': HTDemucs,
        'torch_hdemucs': TorchHDemucsWrapper,
    }[args.model]
    # The config group named after the model holds its hyper-parameters.
    kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
    model = klass(**extra, **kw)
    return model
+
+
def get_optimizer(model, args):
    """Build the optimizer described by `args.optim` for `model`.

    Modules exposing `make_optim_group()` (e.g. layers needing a dedicated
    learning rate) contribute their own parameter group; every remaining
    parameter goes into a default group inserted first.

    Raises:
        ValueError: if `args.optim.optim` is neither 'adam' nor 'adamw'.
    """
    seen_params = set()
    other_params = []
    groups = []
    for n, module in model.named_modules():
        if hasattr(module, "make_optim_group"):
            group = module.make_optim_group()
            params = set(group["params"])
            # A parameter must belong to at most one custom group.
            assert params.isdisjoint(seen_params)
            seen_params |= params
            groups.append(group)
    for param in model.parameters():
        if param not in seen_params:
            other_params.append(param)
    groups.insert(0, {"params": other_params})
    parameters = groups
    if args.optim.optim == "adam":
        return torch.optim.Adam(
            parameters,
            lr=args.optim.lr,
            betas=(args.optim.momentum, args.optim.beta2),
            weight_decay=args.optim.weight_decay,
        )
    elif args.optim.optim == "adamw":
        return torch.optim.AdamW(
            parameters,
            lr=args.optim.lr,
            betas=(args.optim.momentum, args.optim.beta2),
            weight_decay=args.optim.weight_decay,
        )
    else:
        # Fix: previously raised ValueError("Invalid optimizer %s", ...) with a
        # never-formatted template while reading the nonexistent
        # `args.optim.optimizer` attribute, which masked the intended error
        # with an AttributeError.
        raise ValueError(f"Invalid optimizer {args.optim.optim}")
+
+
def get_datasets(args):
    """Build the (train_set, valid_set) pair described by `args.dset`.

    Combines MUSDB with up to two extra wav dataset roots (`wav`, `wav2`).
    `wav2_weight` rebalances the mix by replicating the primary training set.
    """
    if args.dset.backend:
        torchaudio.set_audio_backend(args.dset.backend)
    if args.dset.use_musdb:
        train_set, valid_set = get_musdb_wav_datasets(args.dset)
    else:
        train_set, valid_set = [], []
    if args.dset.wav:
        extra_train_set, extra_valid_set = get_wav_datasets(args.dset)
        # NOTE(review): with more than 4 sources the extra wav set fully
        # replaces MUSDB instead of being concatenated — confirm intent.
        if len(args.dset.sources) <= 4:
            train_set = ConcatDataset([train_set, extra_train_set])
            valid_set = ConcatDataset([valid_set, extra_valid_set])
        else:
            train_set = extra_train_set
            valid_set = extra_valid_set

    if args.dset.wav2:
        extra_train_set, extra_valid_set = get_wav_datasets(args.dset, "wav2")
        weight = args.dset.wav2_weight
        if weight is not None:
            # Replicate the base set so wav2 samples amount to roughly
            # `weight` of the combined training data.
            b = len(train_set)
            e = len(extra_train_set)
            reps = max(1, round(e / b * (1 / weight - 1)))
        else:
            reps = 1
        train_set = ConcatDataset([train_set] * reps + [extra_train_set])
        if args.dset.wav2_valid:
            if weight is not None:
                # Subsample wav2 validation tracks to keep the same ratio.
                b = len(valid_set)
                n_kept = int(round(weight * b / (1 - weight)))
                valid_set = ConcatDataset(
                    [valid_set, random_subset(extra_valid_set, n_kept)]
                )
            else:
                valid_set = ConcatDataset([valid_set, extra_valid_set])
    if args.dset.valid_samples is not None:
        valid_set = random_subset(valid_set, args.dset.valid_samples)
    assert len(train_set)
    assert len(valid_set)
    return train_set, valid_set
+
+
def get_solver(args, model_only=False):
    """Build a fully wired `Solver`: distributed setup, model, optimizer and
    (unless `model_only`) the train/valid data loaders.

    With `args.misc.show`, print model statistics and exit.
    """
    distrib.init()

    torch.manual_seed(args.seed)
    model = get_model(args)
    if args.misc.show:
        logger.info(model)
        # 4 bytes per float32 parameter, expressed in MB.
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.dset.samplerate * 1000)
        sys.exit(0)

    # torch also initialize cuda seed if available
    if torch.cuda.is_available():
        model.cuda()

    # optimizer
    optimizer = get_optimizer(model, args)

    # The configured batch size is global: split it across workers.
    assert args.batch_size % distrib.world_size == 0
    args.batch_size //= distrib.world_size

    if model_only:
        return Solver(None, model, optimizer, args)

    train_set, valid_set = get_datasets(args)

    # NOTE(review): `args.augment.repitch.proba` is tested twice (outer and
    # inner); the inner check is redundant but harmless.
    if args.augment.repitch.proba:
        vocals = []
        if 'vocals' in args.dset.sources:
            vocals.append(args.dset.sources.index('vocals'))
        else:
            logger.warning('No vocal source found')
        if args.augment.repitch.proba:
            train_set = RepitchedWrapper(train_set, vocals=vocals, **args.augment.repitch)

    logger.info("train/valid set size: %d %d", len(train_set), len(valid_set))
    train_loader = distrib.loader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.misc.num_workers, drop_last=True)
    if args.dset.full_cv:
        # Full tracks for validation: batch of 1, keep every track.
        valid_loader = distrib.loader(
            valid_set, batch_size=1, shuffle=False,
            num_workers=args.misc.num_workers)
    else:
        valid_loader = distrib.loader(
            valid_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.misc.num_workers, drop_last=True)
    loaders = {"train": train_loader, "valid": valid_loader}

    # Construct Solver
    return Solver(loaders, model, optimizer, args)
+
+
def get_solver_from_sig(sig, model_only=False):
    """Build a Solver from an existing Dora XP signature `sig`.

    Temporarily tears down any initialized Hydra global state so the other
    XP's config can be loaded, then restores the previous Hydra instance.
    """
    inst = GlobalHydra.instance()
    hyd = None
    if inst.is_initialized():
        hyd = inst.hydra
        inst.clear()
    # `main` is the hydra-decorated entry point defined below in this module.
    xp = main.get_xp_from_sig(sig)
    if hyd is not None:
        inst.clear()
        inst.initialize(hyd)

    with xp.enter(stack=True):
        return get_solver(xp.cfg, model_only)
+
+
@hydra_main(config_path="../conf", config_name="config", version_base="1.1")
def main(args):
    """Hydra entry point: resolve paths, configure logging and run training."""
    global __file__
    # Hydra changes the working directory; make __file__ absolute so any
    # relative-path logic keeps working.
    __file__ = hydra.utils.to_absolute_path(__file__)
    for attr in ["musdb", "wav", "metadata"]:
        val = getattr(args.dset, attr)
        if val is not None:
            setattr(args.dset, attr, hydra.utils.to_absolute_path(val))

    # Keep BLAS/OpenMP single-threaded; parallelism comes from data workers
    # and distributed training instead.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    if args.misc.verbose:
        logger.setLevel(logging.DEBUG)

    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.debug(args)
    from dora import get_xp
    logger.debug(get_xp().cfg)

    solver = get_solver(args)
    solver.train()
+
+
+if '_DORA_TEST_PATH' in os.environ:
+ main.dora.dir = Path(os.environ['_DORA_TEST_PATH'])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demucs/transformer.py b/demucs/transformer.py
new file mode 100644
index 00000000..56a465b8
--- /dev/null
+++ b/demucs/transformer.py
@@ -0,0 +1,839 @@
+# Copyright (c) 2019-present, Meta, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# First author is Simon Rouard.
+
+import random
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import math
+from einops import rearrange
+
+
def create_sin_embedding(
    length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
):
    """Sinusoidal positional embedding in (T, B, C) layout.

    Returns a tensor of shape (length, 1, dim): the first dim//2 channels are
    cosines, the last dim//2 are sines, at geometrically spaced periods.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = shift + torch.arange(length, device=device).view(-1, 1, 1)
    freq_idx = torch.arange(half_dim, device=device).view(1, 1, -1)
    phase = positions / (max_period ** (freq_idx / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """
    2D sinusoidal positional embedding.

    :param d_model: dimension of the model (must be a multiple of 4)
    :param height: height of the positions
    :param width: width of the positions
    :return: position matrix of shape (1, d_model, height, width)
    """
    if d_model % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dimension (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(d_model, height, width)
    # Each dimension use half of d_model: the first half encodes the width
    # axis, the second half the height axis.
    d_model = int(d_model / 2)
    div_term = torch.exp(
        torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
    )
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    # Interleave sin/cos of the width positions on even/odd channels of the
    # first half, broadcast across all heights.
    pe[0:d_model:2, :, :] = (
        torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[1:d_model:2, :, :] = (
        torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    # Same for the height positions on the second half, broadcast across widths.
    pe[d_model::2, :, :] = (
        torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    pe[d_model + 1:: 2, :, :] = (
        torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )

    return pe[None, :].to(device)
+
+
def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,  # True during training
    max_global_shift: float = 0.0,  # delta max
    max_local_shift: float = 0.0,  # epsilon max
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    """CAPE positional embedding in (T, B, C) layout, shape (length, batch_size, dim).

    Positions can be mean-centred and, during training (`augment=True`),
    perturbed by a random global shift, per-position local shifts and a random
    global rescaling drawn with numpy's global RNG.
    """
    assert dim % 2 == 0
    positions = 1.0 * torch.arange(length).view(-1, 1, 1)  # (length, 1, 1)
    positions = positions.repeat(1, batch_size, 1)  # (length, batch_size, 1)
    if mean_normalize:
        positions -= torch.nanmean(positions, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(
            -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
        )
        delta_local = np.random.uniform(
            -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
        )
        log_lambdas = np.random.uniform(
            -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
        )
        positions = (positions + delta + delta_local) * np.exp(log_lambdas)

    positions = positions.to(device)

    half_dim = dim // 2
    freq_idx = torch.arange(half_dim, device=device).view(1, 1, -1)
    phase = positions / (max_period ** (freq_idx / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
+
+
def get_causal_mask(length):
    """Boolean (length, length) mask, True strictly above the diagonal
    (i.e. entry [i, j] is True when j > i)."""
    idx = torch.arange(length)
    return idx[None, :] > idx[:, None]
+
+
def get_elementary_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    Build one boolean attention mask (True = attended position).

    When the input of the Decoder has length T1 and the output T2
    The mask matrix has shape (T2, T1)
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        # First `global_window` input positions are visible to every output...
        mask[:, :global_window] = True
        line_window = int(global_window * T2 / T1)
        # ...and a proportional band of output positions sees every input.
        mask[:line_window, :] = True

    # NOTE: a plain `if` (not `elif`) — it starts a new chain; the "global"
    # case above falls through without entering any branch below.
    if mask_type == "diag":

        # Band mask of width 2*sparse_attn_window+1 around the (rescaled)
        # diagonal, written with scatter_ on clamped column indices.
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        rows = torch.arange(T2)[:, None]
        cols = (
            (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
            .long()
            .clamp(0, T1 - 1)
        )
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))

    elif mask_type == "jmask":
        # Offsets grow as triangular numbers (mirrored around 0), so coverage
        # thins out away from the diagonal. Built on a padded (T2+2, T1+2)
        # grid, then cropped, so clamped indices land in the border.
        mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
        rows = torch.arange(T2 + 2)[:, None]
        t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
        t = (t * (t + 1) / 2).int()
        t = torch.cat([-t.flip(0)[:-1], t])
        cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
        mask = mask[1:-1, 1:-1]

    elif mask_type == "random":
        # Seeded per-element Bernoulli mask: keep with probability 1-sparsity.
        gene = torch.Generator(device=device)
        gene.manual_seed(mask_random_seed)
        mask = (
            torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
            > sparsity
        )

    mask = mask.to(device)
    return mask
+
+
def get_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    Return a SparseCSRTensor mask that is a combination of elementary masks
    mask_type can be a combination of multiple masks: for instance "diag_jmask_random"

    The component masks are OR-ed together (union of attended positions).
    Requires the optional `xformers` dependency.
    """
    from xformers.sparse import SparseCSRTensor
    # create a list
    mask_types = mask_type.split("_")

    all_masks = [
        get_elementary_mask(
            T1,
            T2,
            mask,
            sparse_attn_window,
            global_window,
            mask_random_seed,
            sparsity,
            device,
        )
        for mask in mask_types
    ]

    # Union of the component masks: a position is attended if any mask keeps it.
    final_mask = torch.stack(all_masks).sum(axis=0) > 0

    return SparseCSRTensor.from_dense(final_mask[None])
+
+
class ScaledEmbedding(nn.Module):
    """Embedding with a boosted effective learning rate.

    Stored weights are shrunk by `boost` at init and multiplied back by
    `boost` on every use, so gradients w.r.t. the stored parameter are
    effectively amplified by `boost`.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        scale: float = 1.0,
        boost: float = 3.0,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        with torch.no_grad():
            self.embedding.weight.mul_(scale / boost)
        self.boost = boost

    @property
    def weight(self):
        # Expose the effective (un-shrunk) weight.
        return self.embedding.weight * self.boost

    def forward(self, x):
        return self.embedding(x) * self.boost
+
+
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    Learnt per-channel rescaling of residual branches, initialised near zero.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.full((channels,), float(init)))

    def forward(self, x):
        # Broadcast the per-channel scale over whichever axis holds channels.
        factor = self.scale if self.channel_last else self.scale[:, None]
        return factor * x
+
+
class MyGroupNorm(nn.GroupNorm):
    """GroupNorm accepting (B, T, C) input instead of the usual (B, C, T).

    With num_groups=1 this normalises over all of T and C together for each
    batch element.
    """
    # The original defined a pass-through __init__ that only forwarded
    # *args/**kwargs to nn.GroupNorm; it is redundant and has been removed.

    def forward(self, x):
        """
        x: (B, T, C)
        """
        # nn.GroupNorm expects channels in dim 1, so swap to (B, C, T),
        # normalise, and swap back.
        x = x.transpose(1, 2)
        return super().forward(x).transpose(1, 2)
+
+
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Transformer encoder layer with optional GroupNorm, LayerScale and
    sparse attention (via the module-level `MultiheadAttention`/`get_mask`).

    Extends `nn.TransformerEncoderLayer`; extra options leave the parent's
    behaviour unchanged when disabled.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,  # if truthy, replaces LayerNorms with MyGroupNorm(group_norm groups)
        norm_first=False,
        norm_out=False,  # extra GroupNorm on the output (only with norm_first)
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,  # use the custom sparse MultiheadAttention
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity
        if group_norm:
            # Override the parent's LayerNorms with time+channel GroupNorm.
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # NOTE(review): bitwise `&` on booleans — works, but presumably `and`
        # was intended; behaviour is the same for bool operands.
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        if sparse:
            self.self_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0,
            )
            # Placeholder; a real mask is lazily built on first forward.
            self.__setattr__("src_mask", torch.zeros(1, 1))
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        if batch_first = False, src shape is (T, B, C)
        the case where batch_first=True is not covered
        """
        device = src.device
        x = src
        T, B, C = x.shape
        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            src_mask = self.src_mask
            # Rebuild the cached mask when the sequence length changed.
            if src_mask.shape[-1] != T:
                src_mask = get_mask(
                    T,
                    T,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.__setattr__("src_mask", src_mask)

        if self.norm_first:
            # Pre-norm: norm -> attn/FF -> LayerScale -> residual.
            x = x + self.gamma_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out:
                x = self.norm_out(x)
        else:
            # Post-norm: residual -> norm.
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
+
+
class CrossTransformerEncoderLayer(nn.Module):
    """Cross-attention transformer layer: queries from one stream attend over
    keys/values from the other stream, followed by a feed-forward block.

    Mirrors `MyTransformerEncoderLayer`'s options (GroupNorm, LayerScale,
    pre/post-norm, optional sparse attention).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # NOTE(review): bitwise `&` on booleans — same result as `and` here.
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

        if sparse:
            # Replace the dense attention with the module-level sparse one.
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0)
            if not auto_sparsity:
                # Placeholder; a real mask is lazily built on first forward.
                self.__setattr__("mask", torch.zeros(1, 1))
                self.mask_random_seed = mask_random_seed

    def forward(self, q, k, mask=None):
        """
        Args:
            q: tensor of shape (T, B, C)
            k: tensor of shape (S, B, C)
            mask: tensor of shape (T, S)

        """
        device = q.device
        T, B, C = q.shape
        S, B, C = k.shape
        if self.sparse and not self.auto_sparsity:
            assert mask is None
            mask = self.mask
            # Rebuild the cached mask when either sequence length changed.
            if mask.shape[-1] != S or mask.shape[-2] != T:
                mask = get_mask(
                    S,
                    T,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.__setattr__("mask", mask)

        if self.norm_first:
            # Pre-norm: both streams are normalised before cross-attention.
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    # cross-attention block (q attends over k; k doubles as value)
    def _ca_block(self, q, k, attn_mask=None):
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        # Map legacy string names to the callable activation.
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+
+
+# ----------------- MULTI-BLOCKS MODELS: -----------------------
+
+
class CrossTransformerEncoder(nn.Module):
    """Dual-stream transformer: a spectrogram branch and a temporal branch,
    alternating self-attention layers with cross-attention layers that let
    each branch attend over the other.
    """

    def __init__(
        self,
        dim: int,
        emb: str = "sin",  # positional embedding kind: "sin", "cape" or "scaled"
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,  # only used by the "scaled" embedding
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,  # per-module LR override (see make_optim_group)
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        # NOTE(review): mutable default argument; harmless as long as it is
        # never mutated.
        cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        """
        """
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        # Input normalisation, one per branch (spectral / temporal).
        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({
            "sparse": sparse_self_attn,
        })
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({
            "sparse": sparse_cross_attn,
        })

        # Alternate self-attention and cross-attention layers per parity.
        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:

                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(
                    MyTransformerEncoderLayer(**kwargs_classic_encoder)
                )

            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

                self.layers_t.append(
                    CrossTransformerEncoderLayer(**kwargs_cross_encoder)
                )

    def forward(self, x, xt):
        """
        x: spectral branch, shape (B, C, Fr, T1); xt: temporal branch (B, C, T2).
        Returns both branches in their original layouts.
        """
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(
            C, Fr, T1, x.device, self.max_period
        )  # (1, C, Fr, T1)
        # Flatten (time, freq) into a single token axis for attention.
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")  # now T2, B, C
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                # Self-attention layer on each branch independently.
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Cross-attention: each branch attends over the other's
                # pre-update state.
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        # Positional embedding for the temporal branch, per `self.emb`.
        # NOTE(review): falls through with `pos_emb` unbound for any other
        # value of `self.emb`.
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(
                T, C, shift=shift, device=device, max_period=self.max_period
            )
        elif self.emb == "cape":
            if self.training:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=self.cape_augment,
                    max_global_shift=self.cape_glob_loc_scale[0],
                    max_local_shift=self.cape_glob_loc_scale[1],
                    max_scale=self.cape_glob_loc_scale[2],
                )
            else:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=False,
                )

        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]

        return pos_emb

    def make_optim_group(self):
        # Parameter group honouring this module's weight decay and LR override.
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
+
+
+# Attention Modules
+
+
class MultiheadAttention(nn.Module):
    """Multi-head attention with optional dynamic sparse attention.

    Drop-in replacement for `nn.MultiheadAttention` for the call shape used in
    this file; `auto_sparsity > 0` routes through `dynamic_sparse_attention`
    (xformers), otherwise through the module-level
    `scaled_dot_product_attention`.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,      # accepted for signature parity; unused here
        add_zero_attn=False,    # accepted for signature parity; unused here
        kdim=None,              # accepted for signature parity; unused here
        vdim=None,              # accepted for signature parity; unused here
        batch_first=False,
        auto_sparsity=None,
    ):
        super().__init__()
        assert auto_sparsity is not None, "sanity check"
        self.num_heads = num_heads
        # Separate q/k/v projections (nn.MultiheadAttention fuses them).
        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = torch.nn.Dropout(dropout)
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
        self.proj_drop = torch.nn.Dropout(dropout)
        self.batch_first = batch_first
        self.auto_sparsity = auto_sparsity

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
    ):
        """Return (output, None) — attention weights are never returned."""

        if not self.batch_first:  # N, B, C
            query = query.permute(1, 0, 2)  # B, N_q, C
            key = key.permute(1, 0, 2)  # B, N_k, C
            value = value.permute(1, 0, 2)  # B, N_k, C
        B, N_q, C = query.shape
        B, N_k, C = key.shape

        # Project and split heads: (B, N, C) -> (B*H, N, C//H).
        q = (
            self.q(query)
            .reshape(B, N_q, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        q = q.flatten(0, 1)
        k = (
            self.k(key)
            .reshape(B, N_k, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = k.flatten(0, 1)
        v = (
            self.v(value)
            .reshape(B, N_k, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = v.flatten(0, 1)

        if self.auto_sparsity:
            assert attn_mask is None
            x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
        else:
            x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
        x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)

        # Merge heads back and apply the output projection.
        x = x.transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        return x, None
+
+
def scaled_query_key_softmax(q, k, att_mask):
    """Masked attention weights: softmax(q @ k^T / sqrt(d)) under `att_mask`."""
    from xformers.ops import masked_matmul
    scaled_q = q / (k.size(-1)) ** 0.5
    logits = masked_matmul(scaled_q, k.transpose(-2, -1), att_mask)
    return torch.nn.functional.softmax(logits, -1)
+
+
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    """Attention output: dropout(softmax(q k^T / sqrt(d))) @ v."""
    weights = dropout(scaled_query_key_softmax(q, k, att_mask=att_mask))
    return weights @ v
+
+
+def _compute_buckets(x, R):
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
+ qq = torch.cat([qq, -qq], dim=-1)
+ buckets = qq.argmax(dim=-1)
+
+ return buckets.permute(0, 2, 1).byte().contiguous()
+
+
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    """Sparse attention where attended locations are inferred from LSH buckets
    of the queries and keys (requires a custom xformers build providing
    `find_locations` / `sparse_memory_efficient_attention`)."""
    # assert False, "The code for the custom sparse kernel is not ready for release yet."
    from xformers.ops import find_locations, sparse_memory_efficient_attention
    n_hashes = 32
    proj_size = 4
    query, key, value = [x.contiguous() for x in [query, key, value]]
    with torch.no_grad():
        # Shared random projections so queries and keys hash consistently.
        R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
        bucket_query = _compute_buckets(query, R)
        bucket_key = _compute_buckets(key, R)
        row_offsets, column_indices = find_locations(
            bucket_query, bucket_key, sparsity, infer_sparsity)
    return sparse_memory_efficient_attention(
        query, key, value, row_offsets, column_indices, attn_bias)
diff --git a/demucs/utils.py b/demucs/utils.py
new file mode 100755
index 00000000..a3f5993e
--- /dev/null
+++ b/demucs/utils.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from concurrent.futures import CancelledError
+from contextlib import contextmanager
+import math
+import os
+import tempfile
+import typing as tp
+
+import torch
+from torch.nn import functional as F
+from torch.utils.data import Subset
+
+
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / K)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *batch_shape, length = a.shape
    n_frames = math.ceil(length / stride)
    padded_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, padded_length - length))
    # Build the framed view without copying, via custom strides.
    old_strides = list(a.stride())
    assert old_strides[-1] == 1, 'data should be contiguous'
    new_strides = old_strides[:-1] + [stride, 1]
    return a.as_strided([*batch_shape, n_frames, kernel_size], new_strides)
+
+
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    ref_size: int = (
        reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    )
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta:
        left = delta // 2
        # Keep ref_size samples centred; the odd extra sample comes off the right.
        tensor = tensor[..., left:tensor.size(-1) - (delta - left)]
    return tensor
+
+
def pull_metric(history: tp.List[dict], name: str):
    """Collect the metric at dotted path `name` from each entry of `history`.

    E.g. name="train.loss" reads entry["train"]["loss"] for every entry.
    """
    parts = name.split(".")
    values = []
    for metrics in history:
        node = metrics
        for part in parts:
            node = node[part]
        values.append(node)
    return values
+
+
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatidly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    # Per-key running numerator / denominator of the weighted average.
    numerators: tp.Dict[str, float] = defaultdict(float)
    denominators: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for key, value in metrics.items():
            numerators[key] = numerators[key] * beta + weight * float(value)
            denominators[key] = denominators[key] * beta + weight
        return {key: num / denominators[key] for key, num in numerators.items()}
    return _update
+
+
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    units = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']
    idx = 0
    while idx < len(units) and abs(num) >= 1024.0:
        num /= 1024.0
        idx += 1
    if idx < len(units):
        return "%3.1f%s%s" % (num, units[idx], suffix)
    # Beyond zebibytes, fall back to yobibytes.
    return "%.1f%s%s" % (num, 'Yi', suffix)
+
+
@contextmanager
def temp_filenames(count: int, delete=True):
    """Yield `count` fresh temporary file names, unlinking them on exit
    when `delete` is True."""
    names = []
    try:
        while len(names) < count:
            handle = tempfile.NamedTemporaryFile(delete=False)
            names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
+
+
def random_subset(dataset, max_samples: int, seed: int = 42):
    """Deterministic random subset of at most `max_samples` items.

    Returns the dataset unchanged when it already fits; otherwise a `Subset`
    over a seeded random permutation, so the same seed yields the same subset.
    """
    if max_samples >= len(dataset):
        return dataset

    rng = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=rng)[:max_samples]
    return Subset(dataset, indices.tolist())
+
+
class DummyPoolExecutor:
    """Synchronous stand-in for a futures executor.

    Work is not run at submit time but when `result()` is called; after
    `shutdown()` every pending or new result raises `CancelledError`.
    """

    class DummyResult:
        def __init__(self, func, _dict, *args, **kwargs):
            self.func = func
            self._dict = _dict  # shared {"run": bool} flag with the executor
            self.args = args
            self.kwargs = kwargs

        def result(self):
            # Check the shared flag lazily, so shutdown cancels even
            # previously-submitted work.
            if not self._dict["run"]:
                raise CancelledError()
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        self._dict = {"run": True}

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, self._dict, *args, **kwargs)

    def shutdown(self, *_, **__):
        self._dict["run"] = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
diff --git a/demucs/wav.py b/demucs/wav.py
new file mode 100644
index 00000000..ca1e23a3
--- /dev/null
+++ b/demucs/wav.py
@@ -0,0 +1,255 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Loading wav based datasets, including MusdbHQ."""
+
+from collections import OrderedDict
+import hashlib
+import math
+import json
+import os
+from pathlib import Path
+import tqdm
+
+import musdb
+import julius
+from . import audio_legacy
+import torch as th
+from torch import distributed
+import torchaudio as ta
+from torch.nn import functional as F
+
+from .audio import convert_audio_channels
+from . import distrib
+
+MIXTURE = "mixture"
+EXT = ".wav"
+
+
def _track_metadata(track, sources, normalize=True, ext=EXT):
    """Compute metadata for one track folder: common length and sample rate of
    all stems, plus mixture mean/std when `normalize` is set.

    Raises ValueError if stems disagree on length or sample rate.
    """
    track_length = None
    track_samplerate = None
    mean = 0
    std = 1
    for source in sources + [MIXTURE]:
        file = track / f"{source}{ext}"
        if source == MIXTURE and not file.exists():
            # No mixture file on disk: synthesise it as the sum of the stems
            # and save it for later use.
            audio = 0
            for sub_source in sources:
                sub_file = track / f"{sub_source}{ext}"
                sub_audio, sr = ta.load(sub_file)
                audio += sub_audio
            would_clip = audio.abs().max() >= 1
            if would_clip:
                # Float PCM is required to keep values beyond [-1, 1].
                assert ta.get_audio_backend() == 'soundfile', 'use dset.backend=soundfile'
            ta.save(file, audio, sr, encoding='PCM_F')

        try:
            info = ta.info(str(file))
        except RuntimeError:
            print(file)  # surface the offending file before re-raising
            raise
        length = info.num_frames
        if track_length is None:
            # First stem fixes the reference length and sample rate.
            track_length = length
            track_samplerate = info.sample_rate
        elif track_length != length:
            raise ValueError(
                f"Invalid length for file {file}: "
                f"expecting {track_length} but got {length}.")
        elif info.sample_rate != track_samplerate:
            raise ValueError(
                f"Invalid sample rate for file {file}: "
                f"expecting {track_samplerate} but got {info.sample_rate}.")
        if source == MIXTURE and normalize:
            try:
                wav, _ = ta.load(str(file))
            except RuntimeError:
                print(file)
                raise
            # Normalisation statistics come from the mono-downmixed mixture.
            wav = wav.mean(0)
            mean = wav.mean().item()
            std = wav.std().item()

    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
+
+
def build_metadata(path, sources, normalize=True, ext=EXT):
    """
    Build the metadata for `Wavset`.

    Args:
        path (str or Path): path to dataset.
        sources (list[str]): list of sources to look for.
        normalize (bool): if True, loads full track and store normalization
            values based on the mixture file.
        ext (str): extension of audio files (default is .wav).

    Returns:
        dict: track name (path relative to `path`) -> `_track_metadata` dict.
    """

    meta = {}
    path = Path(path)
    pendings = []
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(8) as pool:
        for root, folders, files in os.walk(path, followlinks=True):
            root = Path(root)
            # Only leaf directories (no subfolders) are treated as tracks;
            # skip hidden folders and the dataset root itself.
            if root.name.startswith('.') or folders or root == path:
                continue
            name = str(root.relative_to(path))
            pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext)))
            # meta[name] = _track_metadata(root, sources, normalize, ext)
        for name, pending in tqdm.tqdm(pendings, ncols=120):
            meta[name] = pending.result()
    return meta
+
+
class Wavset:
    def __init__(
            self,
            root, metadata, sources,
            segment=None, shift=None, normalize=True,
            samplerate=44100, channels=2, ext=EXT):
        """
        Waveset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `path`.
        The folder should contain files named `{source}.{ext}`.

        Args:
            root (Path or str): root folder for the dataset.
            metadata (dict): output from `build_metadata`.
            sources (list[str]): list of source names.
            segment (None or float): segment length in seconds. If `None`, returns entire tracks.
            shift (None or float): stride in seconds between samples.
            normalize (bool): normalizes input audio, **based on the metadata content**,
                i.e. the entire track is normalized, not individual extracts.
            samplerate (int): target sample rate. if the file sample rate
                is different, it will be resampled on the fly.
            channels (int): target nb of channels. if different, will be
                changed on the fly.
            ext (str): extension for audio files (default is .wav).

        samplerate and channels are converted on the fly.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.segment = segment
        self.shift = shift or segment  # default stride: non-overlapping segments
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        self.ext = ext
        # Number of extractable examples per track, aligned with metadata order.
        self.num_examples = []
        for name, meta in self.metadata.items():
            track_duration = meta['length'] / meta['samplerate']
            if segment is None or track_duration < segment:
                examples = 1
            else:
                examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        return sum(self.num_examples)

    def get_file(self, name, source):
        # Path of one stem file inside a track folder.
        return self.root / name / f"{source}{self.ext}"

    def __getitem__(self, index):
        # Walk tracks until the flat `index` falls inside one of them.
        for name, examples in zip(self.metadata, self.num_examples):
            if index >= examples:
                index -= examples
                continue
            meta = self.metadata[name]
            num_frames = -1  # torchaudio convention: -1 loads to the end
            offset = 0
            if self.segment is not None:
                offset = int(meta['samplerate'] * self.shift * index)
                num_frames = int(math.ceil(meta['samplerate'] * self.segment))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            # (num_sources, channels, time), resampled to the target rate.
            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                # Track-level statistics from the metadata, not per-extract.
                example = (example - meta['mean']) / meta['std']
            if self.segment:
                # Trim/pad to the exact segment length after resampling.
                length = int(self.segment * self.samplerate)
                example = example[..., :length]
                example = F.pad(example, (0, length - example.shape[-1]))
            return example
+
+
def get_wav_datasets(args, name='wav'):
    """Extract the wav datasets from the XP arguments.

    Rank 0 builds and caches the metadata JSON (keyed by a hash of the dataset
    path); other ranks wait on a barrier and read the cache.

    Args:
        args: dataset config (provides the path attribute `name`, `metadata`,
            `sources`, `segment`, `shift`, `samplerate`, `channels`,
            `normalize`, `full_cv`).
        name: attribute of `args` holding the dataset root.

    Returns:
        (train_set, valid_set) `Wavset` pair; the valid set also loads the
        mixture stem, and keeps full tracks when `args.full_cv` is set.
    """
    path = getattr(args, name)
    sig = hashlib.sha1(str(path).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('wav_' + sig + ".json")
    train_path = Path(path) / "train"
    valid_path = Path(path) / "valid"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        train = build_metadata(train_path, args.sources)
        valid = build_metadata(valid_path, args.sources)
        # Use a context manager so the handle is closed and flushed before
        # other ranks pass the barrier and read the file (the original
        # `json.dump(..., open(...))` leaked the handle).
        with open(metadata_file, "w") as fp:
            json.dump([train, valid], fp)
    if distrib.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as fp:
        train, valid = json.load(fp)
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(train_path, train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set
+
+
def _get_musdb_valid():
    # Return musdb valid set (track names from musdb's bundled mus.yaml).
    import yaml
    setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
    # Close the file deterministically (the original leaked the handle).
    with open(setup_path, 'r') as fp:
        setup = yaml.safe_load(fp)
    return setup['validation_tracks']
+
+
def get_musdb_wav_datasets(args):
    """Extract the musdb dataset from the XP arguments.

    Rank 0 builds and caches the metadata JSON (keyed by a hash of the musdb
    path); other ranks wait on a barrier and read the cache. The split uses
    musdb's official validation track list.

    Args:
        args: dataset config (provides `musdb`, `metadata`, `sources`,
            `train_valid`, `full_cv`, `segment`, `shift`, `samplerate`,
            `channels`, `normalize`).

    Returns:
        (train_set, valid_set) `Wavset` pair over musdb's "train" folder.
    """
    sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
    root = Path(args.musdb) / "train"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        metadata = build_metadata(root, args.sources)
        # Use a context manager so the handle is closed and flushed before
        # other ranks pass the barrier and read the file (the original
        # `json.dump(..., open(...))` leaked the handle).
        with open(metadata_file, "w") as fp:
            json.dump(metadata, fp)
    if distrib.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as fp:
        metadata = json.load(fp)

    valid_tracks = _get_musdb_valid()
    if args.train_valid:
        # Train on everything, including the official validation tracks.
        metadata_train = metadata
    else:
        metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(root, metadata_train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set
diff --git a/demucs/wdemucs.py b/demucs/wdemucs.py
new file mode 100644
index 00000000..03d6dd3b
--- /dev/null
+++ b/demucs/wdemucs.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# For compat
+from .hdemucs import HDemucs
+
+WDemucs = HDemucs
diff --git a/hatch_build.py b/hatch_build.py
index b7267b85..b3469d36 100644
--- a/hatch_build.py
+++ b/hatch_build.py
@@ -66,6 +66,19 @@ class CustomBuildHook(BuildHookInterface):
print(result.stderr, file=sys.stderr)
print("Successfully built whisper.cpp binaries")
+ # Run the make command for translation files
+ result = subprocess.run(
+ ["make", "translation_mo"],
+ cwd=project_root,
+ check=True,
+ capture_output=True,
+ text=True
+ )
+ print(result.stdout)
+ if result.stderr:
+ print(result.stderr, file=sys.stderr)
+ print("Successfully compiled translation files")
+
# Force include all files in buzz/whisper_cpp directory
whisper_cpp_dir = project_root / "buzz" / "whisper_cpp"
if whisper_cpp_dir.exists():
@@ -88,6 +101,47 @@ class CustomBuildHook(BuildHookInterface):
else:
print(f"Warning: {whisper_cpp_dir} does not exist after build", file=sys.stderr)
+ # Force include all files in demucs directory
+ demucs_dir = project_root / "demucs"
+ if demucs_dir.exists():
+ # Get all files in the demucs directory
+ demucs_files = glob.glob(str(demucs_dir / "**" / "*"), recursive=True)
+
+ # Filter only files (not directories)
+ demucs_files = [f for f in demucs_files if Path(f).is_file()]
+
+ # Add them to force_include
+ if 'force_include' not in build_data:
+ build_data['force_include'] = {}
+
+ for file_path in demucs_files:
+ # Convert to relative path from project root
+ rel_path = Path(file_path).relative_to(project_root)
+ build_data['force_include'][str(rel_path)] = str(rel_path)
+
+ print(f"Force including {len(demucs_files)} files from demucs/")
+ else:
+ print(f"Warning: {demucs_dir} does not exist", file=sys.stderr)
+
+ # Force include all .mo files from buzz/locale directory
+ locale_dir = project_root / "buzz" / "locale"
+ if locale_dir.exists():
+ # Get all .mo files in the locale directory
+ locale_files = glob.glob(str(locale_dir / "**" / "*.mo"), recursive=True)
+
+ # Add them to force_include
+ if 'force_include' not in build_data:
+ build_data['force_include'] = {}
+
+ for file_path in locale_files:
+ # Convert to relative path from project root
+ rel_path = Path(file_path).relative_to(project_root)
+ build_data['force_include'][str(rel_path)] = str(rel_path)
+
+ print(f"Force including {len(locale_files)} .mo files from buzz/locale/")
+ else:
+ print(f"Warning: {locale_dir} does not exist", file=sys.stderr)
+
except subprocess.CalledProcessError as e:
print(f"Error building whisper.cpp: {e}", file=sys.stderr)
print(f"stdout: {e.stdout}", file=sys.stderr)
diff --git a/pyproject.toml b/pyproject.toml
index 144be849..01894149 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "buzz-captions"
-version = "1.3.2"
+version = "1.3.3"
description = ""
authors = [{ name = "Chidi Williams", email = "williamschidi1@gmail.com" }]
requires-python = ">=3.12,<3.13"
@@ -56,7 +56,6 @@ dependencies = [
"treetable>=0.2.5,<0.3",
"soundfile>=0.13.1,<0.14",
"urllib3>=2.3.0,<3",
- "demucs @ https://github.com/raivisdejus/demucs/releases/download/4.1.0a3/demucs-4.1.0a3-py3-none-any.whl",
"posthog>=3.23.0,<4",
"onnxruntime==1.18.1",
"vulkan>=1.3.275.1,<2",
@@ -131,6 +130,7 @@ include = [
"buzz",
"buzz/whisper_cpp/*",
"buzz/locale/*/LC_MESSAGES/buzz.mo",
+ "demucs",
]
[tool.hatch.build.targets.wheel]
@@ -138,6 +138,7 @@ include = [
"buzz",
"buzz/whisper_cpp/*",
"buzz/locale/*/LC_MESSAGES/buzz.mo",
+ "demucs",
]
[tool.hatch.build.hooks.custom]
diff --git a/pytest.ini b/pytest.ini
index b1ef248a..abd57212 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,7 +4,7 @@ log_cli_level = DEBUG
qt_api=pyqt6
log_format = %(asctime)s %(levelname)s %(module)s::%(funcName)s %(message)s
log_date_format = %Y-%m-%d %H:%M:%S
-addopts = -x
+addopts = -x -p no:xdist -p no:pytest_parallel
timeout = 600
timeout_method = thread
markers =
diff --git a/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml b/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
index f03483ac..5faf4bcc 100644
--- a/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
+++ b/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
@@ -64,8 +64,8 @@
-
- https://github.com/chidiwilliams/buzz/releases/tag/v1.3.2
+
+ https://github.com/chidiwilliams/buzz/releases/tag/v1.3.3
This release introduces Vulkan GPU support for whisper.cpp making it significantly faster even on laptops.
Real-time transcription is possible even with large models on computers with ~5GB RAM video cards. There
@@ -77,7 +77,7 @@
Option to switch the UI language from preferences
Library updates for better Linux compatibility, especially in Flatpak installations
Option to upload live transcripts to a server
-
Search and additional controls in Transcription viewer by [@shlomi-dr](https://github.com/shlomi-dr)
+
Search and additional controls in Transcription viewer
Added UI translation for German, Dutch, Danish and Portuguese (Brazilian)
Minor bug fixes
diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml
index bb7ae97b..8017d213 100644
--- a/snap/snapcraft.yaml
+++ b/snap/snapcraft.yaml
@@ -119,8 +119,8 @@ parts:
uv cache clean
# Copy source files
- mkdir -p $CRAFT_PART_INSTALL/buzz
cp -r $CRAFT_PART_BUILD/buzz $CRAFT_PART_INSTALL/
+ cp -r $CRAFT_PART_BUILD/demucs $CRAFT_PART_INSTALL/
# Create desktop file
mkdir -p $CRAFT_PART_INSTALL/usr/share/applications
diff --git a/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py b/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
index 8d34460c..c007caf4 100644
--- a/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
+++ b/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
@@ -778,23 +778,24 @@ class TestTranscriptionViewerWidgetAdditional:
widget.close()
- def test_run_translation(self, qtbot: QtBot, transcription, transcription_service, shortcuts):
- """Test run_translation method"""
- widget = TranscriptionViewerWidget(
- transcription, transcription_service, shortcuts
- )
- qtbot.add_widget(widget)
-
- # Set required options
- widget.transcription_options.llm_model = "gpt-4"
- widget.transcription_options.llm_prompt = "Translate"
-
- widget.run_translation()
-
- # Should enqueue translation tasks
- assert hasattr(widget, 'run_translation')
-
- widget.close()
+ # Skipped as it seems it is sending actual requests and maybe failing on CI
+ # def test_run_translation(self, qtbot: QtBot, transcription, transcription_service, shortcuts):
+ # """Test run_translation method"""
+ # widget = TranscriptionViewerWidget(
+ # transcription, transcription_service, shortcuts
+ # )
+ # qtbot.add_widget(widget)
+ #
+ # # Set required options
+ # widget.transcription_options.llm_model = "gpt-4"
+ # widget.transcription_options.llm_prompt = "Translate"
+ #
+ # widget.run_translation()
+ #
+ # # Should enqueue translation tasks
+ # assert hasattr(widget, 'run_translation')
+ #
+ # widget.close()
def test_restore_ui_state(self, qtbot: QtBot, transcription, transcription_service, shortcuts):
"""Test restore_ui_state method"""
diff --git a/uv.lock b/uv.lock
index f09bc661..0c326aa8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -128,7 +128,7 @@ wheels = [
[[package]]
name = "buzz-captions"
-version = "1.3.2"
+version = "1.3.3"
source = { editable = "." }
dependencies = [
{ name = "accelerate" },
@@ -137,7 +137,6 @@ dependencies = [
{ name = "ctranslate2", version = "4.6.0", source = { registry = "https://pypi.org/simple/" }, marker = "platform_machine == 'arm64' or sys_platform != 'darwin'" },
{ name = "darkdetect" },
{ name = "dataclasses-json" },
- { name = "demucs" },
{ name = "diffq" },
{ name = "dora-search" },
{ name = "einops" },
@@ -218,7 +217,6 @@ requires-dist = [
{ name = "ctranslate2", marker = "platform_machine == 'x86_64' and sys_platform == 'darwin'", specifier = "==4.3.1" },
{ name = "darkdetect", specifier = ">=0.8.0,<0.9" },
{ name = "dataclasses-json", specifier = ">=0.6.4,<0.7" },
- { name = "demucs", url = "https://github.com/raivisdejus/demucs/releases/download/4.1.0a3/demucs-4.1.0a3-py3-none-any.whl" },
{ name = "diffq", specifier = ">=0.2.4,<0.3" },
{ name = "dora-search", specifier = ">=0.1.12,<0.2" },
{ name = "einops", specifier = ">=0.8.1,<0.9" },
@@ -566,62 +564,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" },
]
-[[package]]
-name = "demucs"
-version = "4.1.0a3"
-source = { url = "https://github.com/raivisdejus/demucs/releases/download/4.1.0a3/demucs-4.1.0a3-py3-none-any.whl" }
-dependencies = [
- { name = "dora-search" },
- { name = "einops" },
- { name = "julius" },
- { name = "lameenc" },
- { name = "openunmix" },
- { name = "pyyaml" },
- { name = "torch", version = "2.2.2", source = { registry = "https://pypi.org/simple/" }, marker = "platform_machine == 'x86_64' and sys_platform == 'darwin'" },
- { name = "torch", version = "2.7.1", source = { registry = "https://pypi.org/simple/" }, marker = "platform_machine != 'x86_64' and sys_platform == 'darwin'" },
- { name = "torch", version = "2.7.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" },
- { name = "torchaudio", version = "2.2.2", source = { registry = "https://pypi.org/simple/" }, marker = "platform_machine == 'x86_64' and sys_platform == 'darwin'" },
- { name = "torchaudio", version = "2.7.1", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" },
- { name = "torchaudio", version = "2.7.1", source = { registry = "https://pypi.org/simple/" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
- { name = "torchaudio", version = "2.7.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 'arm64' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_python_implementation != 'CPython' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
- { name = "tqdm" },
-]
-wheels = [
- { url = "https://github.com/raivisdejus/demucs/releases/download/4.1.0a3/demucs-4.1.0a3-py3-none-any.whl", hash = "sha256:3c52712c0b6022f7e26a00b0cfb4e4ed04ed9994f78f06cfa485dc7006cbef60" },
-]
-
-[package.metadata]
-requires-dist = [
- { name = "diffq", marker = "extra == 'dev'", specifier = ">=0.2.1" },
- { name = "dora-search" },
- { name = "dora-search", marker = "extra == 'dev'", specifier = ">=0.1.12" },
- { name = "einops" },
- { name = "einops", marker = "extra == 'dev'" },
- { name = "flake8", marker = "extra == 'dev'" },
- { name = "hydra-colorlog", marker = "extra == 'dev'", specifier = ">=1.1" },
- { name = "hydra-core", marker = "extra == 'dev'", specifier = ">=1.1" },
- { name = "julius", specifier = ">=0.2.3" },
- { name = "julius", marker = "extra == 'dev'", specifier = ">=0.2.3" },
- { name = "lameenc", specifier = ">=1.2" },
- { name = "lameenc", marker = "extra == 'dev'", specifier = ">=1.2" },
- { name = "museval", marker = "extra == 'dev'" },
- { name = "mypy", marker = "extra == 'dev'" },
- { name = "openunmix" },
- { name = "openunmix", marker = "extra == 'dev'" },
- { name = "pyyaml" },
- { name = "pyyaml", marker = "extra == 'dev'" },
- { name = "soundfile", marker = "extra == 'dev'", specifier = ">=0.10.3" },
- { name = "submitit", marker = "extra == 'dev'" },
- { name = "torch", specifier = ">=1.8.1" },
- { name = "torch", marker = "extra == 'dev'", specifier = ">=1.8.1" },
- { name = "torchaudio", specifier = ">=0.8" },
- { name = "torchaudio", marker = "extra == 'dev'", specifier = ">=0.8" },
- { name = "tqdm" },
- { name = "tqdm", marker = "extra == 'dev'" },
- { name = "treetable", marker = "extra == 'dev'" },
-]
-provides-extras = ["dev"]
-
[[package]]
name = "diffq"
version = "0.2.4"
From ccdeb09ac9030d74a768dece17af049ae0b32f37 Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Sun, 9 Nov 2025 11:52:20 +0200
Subject: [PATCH 03/73] Fix for translator test (#1280)
---
.../local_whisper_cpp_server_transcriber.py | 3 +-
.../openai_whisper_api_file_transcriber.py | 3 +-
buzz/transcriber/recording_transcriber.py | 3 +-
buzz/translator.py | 27 ++++++++------
.../general_preferences_widget.py | 2 +-
.../transcription_viewer_widget.py | 11 ++++--
tests/translator_test.py | 17 +++++----
...scription_viewer_widget_additional_test.py | 36 +++++++++----------
8 files changed, 58 insertions(+), 44 deletions(-)
diff --git a/buzz/transcriber/local_whisper_cpp_server_transcriber.py b/buzz/transcriber/local_whisper_cpp_server_transcriber.py
index c58553d9..d57252fe 100644
--- a/buzz/transcriber/local_whisper_cpp_server_transcriber.py
+++ b/buzz/transcriber/local_whisper_cpp_server_transcriber.py
@@ -64,7 +64,8 @@ class LocalWhisperCppServerTranscriber(OpenAIWhisperAPIFileTranscriber):
self.openai_client = OpenAI(
api_key="not-used",
- base_url="http://127.0.0.1:3000"
+ base_url="http://127.0.0.1:3000",
+ max_retries=0
)
def transcribe(self) -> List[Segment]:
diff --git a/buzz/transcriber/openai_whisper_api_file_transcriber.py b/buzz/transcriber/openai_whisper_api_file_transcriber.py
index 21a6652f..b2f02898 100644
--- a/buzz/transcriber/openai_whisper_api_file_transcriber.py
+++ b/buzz/transcriber/openai_whisper_api_file_transcriber.py
@@ -46,7 +46,8 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
self.task = task.transcription_options.task
self.openai_client = OpenAI(
api_key=self.transcription_task.transcription_options.openai_access_token,
- base_url=custom_openai_base_url if custom_openai_base_url else None
+ base_url=custom_openai_base_url if custom_openai_base_url else None,
+ max_retries=0
)
self.whisper_api_model = get_custom_api_whisper_model(custom_openai_base_url)
self.word_level_timings = self.transcription_task.transcription_options.word_level_timings
diff --git a/buzz/transcriber/recording_transcriber.py b/buzz/transcriber/recording_transcriber.py
index 8e5cc3d1..7867e50e 100644
--- a/buzz/transcriber/recording_transcriber.py
+++ b/buzz/transcriber/recording_transcriber.py
@@ -126,7 +126,8 @@ class RecordingTranscriber(QObject):
self.whisper_api_model = get_custom_api_whisper_model(custom_openai_base_url)
self.openai_client = OpenAI(
api_key=self.transcription_options.openai_access_token,
- base_url=custom_openai_base_url if custom_openai_base_url else None
+ base_url=custom_openai_base_url if custom_openai_base_url else None,
+ max_retries=0
)
logging.debug("Will use whisper API on %s, %s",
custom_openai_base_url, self.whisper_api_model)
diff --git a/buzz/translator.py b/buzz/translator.py
index 56a816ea..ffeecf7b 100644
--- a/buzz/translator.py
+++ b/buzz/translator.py
@@ -3,7 +3,7 @@ import logging
import queue
from typing import Optional
-from openai import OpenAI
+from openai import OpenAI
from PyQt6.QtCore import QObject, pyqtSignal
from buzz.settings.settings import Settings
@@ -15,7 +15,6 @@ from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDi
class Translator(QObject):
translation = pyqtSignal(str, int)
finished = pyqtSignal()
- is_running = False
def __init__(
self,
@@ -48,19 +47,22 @@ class Translator(QObject):
)
self.openai_client = OpenAI(
api_key=openai_api_key,
- base_url=custom_openai_base_url if custom_openai_base_url else None
+ base_url=custom_openai_base_url if custom_openai_base_url else None,
+ max_retries=0
)
def start(self):
logging.debug("Starting translation queue")
- self.is_running = True
+ while True:
+ item = self.queue.get() # Block until item available
- while self.is_running:
- try:
- transcript, transcript_id = self.queue.get(timeout=1)
- except queue.Empty:
- continue
+ # Check for sentinel value (None means stop)
+ if item is None:
+ logging.debug("Translation queue received stop signal")
+ break
+
+ transcript, transcript_id = item
try:
completion = self.openai_client.chat.completions.create(
@@ -69,7 +71,8 @@ class Translator(QObject):
{"role": "system", "content": self.transcription_options.llm_prompt},
{"role": "user", "content": transcript}
],
- timeout=30.0
+ timeout=30.0,
+
)
except Exception as e:
completion = None
@@ -84,6 +87,7 @@ class Translator(QObject):
self.translation.emit(next_translation, transcript_id)
+ logging.debug("Translation queue stopped")
self.finished.emit()
def on_transcription_options_changed(
@@ -95,4 +99,5 @@ class Translator(QObject):
self.queue.put((transcript, transcript_id))
def stop(self):
- self.is_running = False
+ # Send sentinel value to unblock and stop the worker thread
+ self.queue.put(None)
diff --git a/buzz/widgets/preferences_dialog/general_preferences_widget.py b/buzz/widgets/preferences_dialog/general_preferences_widget.py
index 5cefcdaa..b7bdfc74 100644
--- a/buzz/widgets/preferences_dialog/general_preferences_widget.py
+++ b/buzz/widgets/preferences_dialog/general_preferences_widget.py
@@ -328,7 +328,7 @@ class ValidateOpenAIApiKeyJob(QRunnable):
client = OpenAI(
api_key=self.api_key,
base_url=custom_openai_base_url if custom_openai_base_url else None,
- timeout=5,
+ timeout=15,
)
client.models.list()
self.signals.success.emit()
diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
index bf4400b3..e77c2179 100644
--- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
+++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
@@ -1351,8 +1351,15 @@ class TranscriptionViewerWidget(QWidget):
# Only wait if thread is actually running
if self.translation_thread.isRunning():
- if not self.translation_thread.wait(45_000):
- logging.warning("Translation thread did not finish within timeout")
+ # Wait up to 35 seconds for graceful shutdown
+ # (30s max API call timeout + 5s buffer)
+ if not self.translation_thread.wait(35_000):
+ logging.warning("Translation thread did not finish gracefully, terminating")
+ # Force terminate the thread if it doesn't stop
+ self.translation_thread.terminate()
+ # Give it a brief moment to terminate
+ if not self.translation_thread.wait(1_000):
+ logging.error("Translation thread could not be terminated")
super().closeEvent(event)
diff --git a/tests/translator_test.py b/tests/translator_test.py
index 6c0f87d6..c9b4d8e3 100644
--- a/tests/translator_test.py
+++ b/tests/translator_test.py
@@ -15,14 +15,12 @@ class TestTranslator:
@patch('buzz.translator.queue.Queue', autospec=True)
def test_start(self, mock_queue, mock_openai, qtbot):
def side_effect(*args, **kwargs):
- side_effect.call_count += 1
+ if side_effect.call_count <= 1:
+ side_effect.call_count += 1
+ return ("Hello, how are you?", 1)
- if side_effect.call_count >= 5:
- translator.is_running = False
-
- if side_effect.call_count < 3:
- raise Empty
- return "Hello, how are you?", None
+ # Finally return sentinel to stop
+ return None
side_effect.call_count = 0
@@ -51,6 +49,8 @@ class TestTranslator:
mock_queue.get.assert_called()
mock_chat.completions.create.assert_called()
+ translator.stop()
+
@patch('buzz.translator.OpenAI', autospec=True)
def test_translator(self, mock_openai, qtbot):
@@ -94,8 +94,7 @@ class TestTranslator:
self.translation_thread.start()
- time.sleep(3)
- assert self.translator.is_running
+ time.sleep(1) # Give thread time to start
self.translator.enqueue("Hello, how are you?")
diff --git a/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py b/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
index c007caf4..9e716e7a 100644
--- a/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
+++ b/tests/widgets/transcription_viewer/transcription_viewer_widget_additional_test.py
@@ -778,24 +778,24 @@ class TestTranscriptionViewerWidgetAdditional:
widget.close()
- # Skipped as it seems it is sending actual requests and maybe failing on CI
- # def test_run_translation(self, qtbot: QtBot, transcription, transcription_service, shortcuts):
- # """Test run_translation method"""
- # widget = TranscriptionViewerWidget(
- # transcription, transcription_service, shortcuts
- # )
- # qtbot.add_widget(widget)
- #
- # # Set required options
- # widget.transcription_options.llm_model = "gpt-4"
- # widget.transcription_options.llm_prompt = "Translate"
- #
- # widget.run_translation()
- #
- # # Should enqueue translation tasks
- # assert hasattr(widget, 'run_translation')
- #
- # widget.close()
+ # TODO - it is sending actual requests, should mock
+ def test_run_translation(self, qtbot: QtBot, transcription, transcription_service, shortcuts):
+ """Test run_translation method"""
+ widget = TranscriptionViewerWidget(
+ transcription, transcription_service, shortcuts
+ )
+ qtbot.add_widget(widget)
+
+ # Set required options
+ widget.transcription_options.llm_model = "gpt-4"
+ widget.transcription_options.llm_prompt = "Translate"
+
+ widget.run_translation()
+
+ # Should enqueue translation tasks
+ assert hasattr(widget, 'run_translation')
+
+ widget.close()
def test_restore_ui_state(self, qtbot: QtBot, transcription, transcription_service, shortcuts):
"""Test restore_ui_state method"""
From 070d9f17d576716d98e1945425bb7414011d14a4 Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Sun, 9 Nov 2025 21:57:39 +0200
Subject: [PATCH 04/73] Documentation adjustments (#1281)
---
.github/workflows/ci.yml | 73 ++++++++++---------
README.md | 48 ++++++------
docs/docs/faq.md | 2 +-
.../io.github.chidiwilliams.Buzz.metainfo.xml | 2 +-
4 files changed, 60 insertions(+), 65 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index dbfa02f0..010e183a 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -357,40 +357,41 @@ jobs:
with:
files: |
Buzz*-unix.tar.gz
- Buzz*-windows.exe
- Buzz*-windows-*.bin
- Buzz*-mac.dmg
+ Buzz*.exe
+ Buzz*.bin
+ Buzz*.dmg
- deploy_brew_cask:
- runs-on: macos-latest
- env:
- BUZZ_DISABLE_TELEMETRY: true
- needs: [release]
- if: startsWith(github.ref, 'refs/tags/')
- steps:
- - uses: actions/checkout@v4
- with:
- submodules: recursive
-
- # Should be removed with next update to whisper.cpp
- - name: Downgrade Xcode
- uses: maxim-lobanov/setup-xcode@v1
- with:
- xcode-version: '16.0.0'
- if: matrix.os == 'macos-latest'
-
- - name: Install uv
- uses: astral-sh/setup-uv@v6
-
- - name: Set up Python
- uses: actions/setup-python@v5
- with:
- python-version: "3.12"
-
- - name: Install dependencies
- run: uv sync
-
- - name: Upload to Brew
- run: uv run make upload_brew
- env:
- HOMEBREW_GITHUB_API_TOKEN: ${{ secrets.HOMEBREW_GITHUB_API_TOKEN }}
+# Brew Cask deployment fails and the app is deprecated on Brew.
+# deploy_brew_cask:
+# runs-on: macos-latest
+# env:
+# BUZZ_DISABLE_TELEMETRY: true
+# needs: [release]
+# if: startsWith(github.ref, 'refs/tags/')
+# steps:
+# - uses: actions/checkout@v4
+# with:
+# submodules: recursive
+#
+# # Should be removed with next update to whisper.cpp
+# - name: Downgrade Xcode
+# uses: maxim-lobanov/setup-xcode@v1
+# with:
+# xcode-version: '16.0.0'
+# if: matrix.os == 'macos-latest'
+#
+# - name: Install uv
+# uses: astral-sh/setup-uv@v6
+#
+# - name: Set up Python
+# uses: actions/setup-python@v5
+# with:
+# python-version: "3.12"
+#
+# - name: Install dependencies
+# run: uv sync
+#
+# - name: Upload to Brew
+# run: uv run make upload_brew
+# env:
+# HOMEBREW_GITHUB_API_TOKEN: ${{ secrets.HOMEBREW_GITHUB_API_TOKEN }}
diff --git a/README.md b/README.md
index 173d25e4..55c62f9d 100644
--- a/README.md
+++ b/README.md
@@ -22,26 +22,9 @@ OpenAI's [Whisper](https://github.com/openai/whisper).
## Installation
-### PyPI
-
-Install [ffmpeg](https://www.ffmpeg.org/download.html)
-
-Install Buzz
-
-```shell
-pip install buzz-captions
-python -m buzz
-```
-
### macOS
-Install with [brew utility](https://brew.sh/)
-
-```shell
-brew install --cask buzz
-```
-
-Or download the `.dmg` from the [SourceForge](https://sourceforge.net/projects/buzz-captions/files/).
+Download the `.dmg` from [SourceForge](https://sourceforge.net/projects/buzz-captions/files/).
### Windows
@@ -55,15 +38,6 @@ App is not signed, you will get a warning when you install it. Select `More info
winget install ChidiWilliams.Buzz
```
-**GPU support for PyPI**
-
-To have GPU support for Nvidia GPUS on Windows, for PyPI installed version ensure, CUDA support for [torch](https://pytorch.org/get-started/locally/)
-
-```
-pip3 install -U torch==2.7.1+cu128 torchaudio==2.7.1+cu128 --index-url https://download.pytorch.org/whl/cu128
-pip3 install nvidia-cublas-cu12==12.8.3.14 nvidia-cuda-cupti-cu12==12.8.57 nvidia-cuda-nvrtc-cu12==12.8.61 nvidia-cuda-runtime-cu12==12.8.57 nvidia-cudnn-cu12==9.7.1.26 nvidia-cufft-cu12==11.3.3.41 nvidia-curand-cu12==10.3.9.55 nvidia-cusolver-cu12==11.7.2.55 nvidia-cusparse-cu12==12.5.4.2 nvidia-cusparselt-cu12==0.6.3 nvidia-nvjitlink-cu12==12.8.61 nvidia-nvtx-cu12==12.8.55 --extra-index-url https://pypi.ngc.nvidia.com
-```
-
### Linux
Buzz is available as a [Flatpak](https://flathub.org/apps/io.github.chidiwilliams.Buzz) or a [Snap](https://snapcraft.io/buzz).
@@ -80,6 +54,26 @@ sudo snap install buzz
sudo snap connect buzz:password-manager-service
```
+### PyPI
+
+Install [ffmpeg](https://www.ffmpeg.org/download.html)
+
+Install Buzz
+
+```shell
+pip install buzz-captions
+python -m buzz
+```
+
+**GPU support for PyPI**
+
+To enable GPU support for Nvidia GPUs on Windows with the PyPI-installed version, ensure CUDA support for [torch](https://pytorch.org/get-started/locally/)
+
+```
+pip3 install -U torch==2.7.1+cu128 torchaudio==2.7.1+cu128 --index-url https://download.pytorch.org/whl/cu128
+pip3 install nvidia-cublas-cu12==12.8.3.14 nvidia-cuda-cupti-cu12==12.8.57 nvidia-cuda-nvrtc-cu12==12.8.61 nvidia-cuda-runtime-cu12==12.8.57 nvidia-cudnn-cu12==9.7.1.26 nvidia-cufft-cu12==11.3.3.41 nvidia-curand-cu12==10.3.9.55 nvidia-cusolver-cu12==11.7.2.55 nvidia-cusparse-cu12==12.5.4.2 nvidia-cusparselt-cu12==0.6.3 nvidia-nvjitlink-cu12==12.8.61 nvidia-nvtx-cu12==12.8.55 --extra-index-url https://pypi.ngc.nvidia.com
+```
+
### Latest development version
For info on how to get latest development version with latest features and bug fixes see [FAQ](https://chidiwilliams.github.io/buzz/docs/faq#9-where-can-i-get-latest-development-version).
diff --git a/docs/docs/faq.md b/docs/docs/faq.md
index 10d74409..4de7f377 100644
--- a/docs/docs/faq.md
+++ b/docs/docs/faq.md
@@ -84,7 +84,7 @@ gsettings set org.gnome.desktop.interface color-scheme prefer-dark
If your system theme is not applied to Buzz installed from Flatpak Linux app store, ensure the desired theme is in `~/.themes` folder.
-You may need to copy the system themes to this folder `cp -r /usr/share/themes/ ~/.themes/`.
+You may need to copy the system themes to this folder `cp -r /usr/share/themes/ ~/.themes/` and give Flatpaks access to this folder `flatpak override --user --filesystem=~/.themes`.
On Fedora run the following to install the necessary packages
`sudo dnf install gnome-themes-extra qadwaitadecorations-qt{5,6} qt{5,6}-qtwayland`
\ No newline at end of file
diff --git a/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml b/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
index 5faf4bcc..d65251fd 100644
--- a/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
+++ b/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
@@ -16,7 +16,7 @@
Required permissions in Buzz will let you select audio and video files for transcription, from most common file location on your computer. Network permission is used to download transcription model files. Microphone permission lets you transcribe real time speech.
- Note: If your system theme is not applied to Buzz, ensure it is in ~/.themes folder. You may need to copy the system themes to this folder cp -r /usr/share/themes/ ~/.themes/.
+ Note: If your system theme is not applied to Buzz, ensure it is in ~/.themes folder. You may need to copy the system themes to this folder cp -r /usr/share/themes/ ~/.themes/ and give Flatpaks access to this folder flatpak override --user --filesystem=~/.themes.
From 629fa9f1f7c10bef09f608aef6c911dfb8b930fb Mon Sep 17 00:00:00 2001
From: albanobattistella <34811668+albanobattistella@users.noreply.github.com>
Date: Sun, 9 Nov 2025 21:36:33 +0100
Subject: [PATCH 05/73] Update buzz.po (#1282)
---
buzz/locale/it_IT/LC_MESSAGES/buzz.po | 87 ++++++++++++++-------------
1 file changed, 45 insertions(+), 42 deletions(-)
diff --git a/buzz/locale/it_IT/LC_MESSAGES/buzz.po b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
index 2159b4b4..b1206756 100644
--- a/buzz/locale/it_IT/LC_MESSAGES/buzz.po
+++ b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgstr ""
"Project-Id-Version: buzz\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2025-10-12 19:10+0300\n"
-"PO-Revision-Date: 2025-05-30 15:22+0100\n"
+"PO-Revision-Date: 2025-11-09 20:22+0200\n"
"Language-Team: (Italiano) Albano Battistella \n"
"Language: it_IT\n"
"MIME-Version: 1.0\n"
@@ -46,7 +46,7 @@ msgstr "URL:"
#: buzz/widgets/import_url_dialog.py:44
msgid "Invalid URL"
-msgstr "URL non valido"
+msgstr "URL non valido"
#: buzz/widgets/import_url_dialog.py:44
msgid "The URL you entered is invalid."
@@ -107,8 +107,7 @@ msgid "Polish"
msgstr "Polacco"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:45
-#, fuzzy
-msgid "Portuguese (Brazil)"
+msgid "Portuguese"
msgstr "Portoghese"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:46
@@ -172,15 +171,15 @@ msgstr "Modalità di registrazione in diretta"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:183
msgid "Use only CPU and disable GPU acceleration"
-msgstr ""
+msgstr "Utilizza solo la CPU e disattiva l'accelerazione GPU"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:186
msgid "Set this if larger models do not fit your GPU memory and Buzz crashes"
-msgstr ""
+msgstr "Imposta questa opzione se i modelli più grandi non si adattano alla memoria della tua GPU e Buzz si blocca"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:188
msgid "Disable GPU"
-msgstr ""
+msgstr "Disabilita GPU"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:213
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:219
@@ -392,6 +391,8 @@ msgid ""
"Enter instructions for AI on how to translate, for example 'Please translate "
"each text sent to you from English to Spanish.'"
msgstr ""
+"Inserisci le istruzioni per l'IA su come tradurre, ad esempio 'Per favore, traduci "
+"ogni testo che ti viene inviato dall'inglese allo spagnolo.'"
#: buzz/widgets/transcriber/advanced_settings_dialog.py:92
msgid "Instructions for AI:"
@@ -562,86 +563,88 @@ msgstr "Ridimensionare"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
msgid "Find"
-msgstr ""
+msgstr "Trova"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
msgid "Show/Hide Search Bar (Ctrl+F)"
-msgstr ""
+msgstr "Mostra/Nascondi barra di ricerca (Ctrl+F)"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
msgid "Find:"
-msgstr ""
+msgstr "Trova:"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
msgid "Enter text to find..."
-msgstr ""
+msgstr "Inserisci il testo per trovare..."
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
msgid "Previous match (Shift+Enter)"
-msgstr ""
+msgstr "Corrispondenza precedente (Maiusc+Invio)"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
msgid "Next match (Enter)"
-msgstr ""
+msgstr "Prossima corrispondenza (Invio)"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
msgid "Clear"
-msgstr ""
+msgstr "Elimina"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
msgid "Playback Controls:"
-msgstr ""
+msgstr "Controlli di riproduzione:"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
msgid "Loop Segment"
-msgstr ""
+msgstr "Ciclo di segmento"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
msgid "Enable/disable looping when clicking on transcript segments"
-msgstr ""
+msgstr "Abilita/disabilita il loop quando si fa clic sui segmenti della trascrizione"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
msgid "Follow Audio"
-msgstr ""
+msgstr "Segui Audio"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
+"Abilita/disabilita la lettura della posizione audio corrente nella trascrizione. Quando "
+"abilitato, scorre automaticamente fino al testo corrente."
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
msgid "Scroll to Current"
-msgstr ""
+msgstr "Scorri fino al Corrente"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
msgid "Scroll to the currently spoken text"
-msgstr ""
+msgstr "Scorrere fino al testo attualmente pronunciato"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
msgid "1 of 100+ matches"
-msgstr ""
+msgstr "1 di 100+ corrispondenze"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
msgid "1 of "
-msgstr ""
+msgstr "1 di "
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
msgid " matches"
-msgstr ""
+msgstr " corrispondenze"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
msgid "No matches found"
-msgstr ""
+msgstr "Nessuna corrispondenza trovata"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
msgid " of 100+ matches"
-msgstr ""
+msgstr " di oltre 100 corrispondenze"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
msgid " of "
-msgstr ""
+msgstr " di "
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
msgid "API Key Required"
@@ -761,7 +764,7 @@ msgstr "Impossibile salvare la chiave API OpenAI nel portachiavi"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
#: buzz/transcriber/recording_transcriber.py:394
msgid "Whisper server failed to start. Check logs for details."
-msgstr ""
+msgstr "Impossibile avviare il server Whisper. Controllare i log per i dettagli."
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
#: buzz/transcriber/recording_transcriber.py:398
@@ -770,11 +773,13 @@ msgid ""
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
"variable."
msgstr ""
+"Impossibile avviare il server Whisper a causa di memoria insufficiente. Riprovare "
+"con un modello più piccolo. Per forzare la modalità CPU, utilizzare la variabile d'ambiente "
+"BUZZ_FORCE_CPU=TRUE"
#: buzz/transcriber/transcriber.py:24
-#, fuzzy
msgid "Translate to English"
-msgstr "Impostazioni di traduzione"
+msgstr "Traduci in inglese"
#: buzz/transcriber/transcriber.py:25
msgid "Transcribe"
@@ -1142,12 +1147,11 @@ msgstr "Si è verificato un errore di connessione"
#: buzz/transcriber/recording_transcriber.py:332
msgid "Starting Whisper.cpp..."
-msgstr ""
+msgstr "Avvio di Whisper.cpp..."
#: buzz/transcriber/recording_transcriber.py:385
-#, fuzzy
msgid "Starting transcription..."
-msgstr "Annulla trascrizione"
+msgstr "Inizio trascrizione..."
#: buzz/settings/shortcut.py:17
msgid "Open Record Window"
@@ -1174,41 +1178,40 @@ msgid "View Transcript Timestamps"
msgstr "Visualizza i timestamp della trascrizione"
#: buzz/settings/shortcut.py:25
-#, fuzzy
msgid "Search Transcript"
-msgstr "Apri trascrizione"
+msgstr "Cerca trascrizione"
#: buzz/settings/shortcut.py:26
msgid "Scroll to Current Text"
-msgstr ""
+msgstr "Scorri fino al testo corrente"
#: buzz/settings/shortcut.py:27
msgid "Play/Pause Audio"
-msgstr ""
+msgstr "Riproduci/Pausa audio"
#: buzz/settings/shortcut.py:28
msgid "Replay Current Segment"
-msgstr ""
+msgstr "Riproduci il segmento corrente"
#: buzz/settings/shortcut.py:29
msgid "Toggle Playback Controls"
-msgstr ""
+msgstr "Attiva/disattiva i controlli di riproduzione"
#: buzz/settings/shortcut.py:31
msgid "Decrease Segment Start Time"
-msgstr ""
+msgstr "Riduci l'ora di inizio del segmento"
#: buzz/settings/shortcut.py:32
msgid "Increase Segment Start Time"
-msgstr ""
+msgstr "Aumenta l'ora di inizio del segmento"
#: buzz/settings/shortcut.py:33
msgid "Decrease Segment End Time"
-msgstr ""
+msgstr "Diminuisci l'ora di fine del segmento"
#: buzz/settings/shortcut.py:34
msgid "Increase Segment End Time"
-msgstr ""
+msgstr "Aumenta l'ora di fine del segmento"
#: buzz/settings/recording_transcriber_mode.py:5
msgid "Append below"
From 93559530abcd95393cbe52b478598c93dfc5d1b0 Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Mon, 17 Nov 2025 22:53:06 +0200
Subject: [PATCH 06/73] Adjusting flatpak meta (#1285)
---
CONTRIBUTING.md | 3 +-
io.github.chidiwilliams.Buzz.yml | 90 -------------------
share/icons/io.github.chidiwilliams.Buzz.svg | 37 ++++----
.../io.github.chidiwilliams.Buzz.metainfo.xml | 2 +-
4 files changed, 25 insertions(+), 107 deletions(-)
delete mode 100644 io.github.chidiwilliams.Buzz.yml
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ea9fb22e..43df6166 100755
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -28,7 +28,8 @@ What version of the Buzz are you using? On what OS? What are steps to reproduce
**Logs**
Log files contain valuable information about what the Buzz was doing before the issue occurred. You can get the logs like this:
-* Mac and Linux run the app from the terminal and check the output.
+* Linux run the app from the terminal and check the output.
+* Mac get logs from `~/Library/Logs/Buzz`.
* Windows paste this into the Windows Explorer address bar `%USERPROFILE%\AppData\Local\Buzz\Buzz\Logs` and check the logs file.
**Test on latest version**
diff --git a/io.github.chidiwilliams.Buzz.yml b/io.github.chidiwilliams.Buzz.yml
deleted file mode 100644
index 10536b23..00000000
--- a/io.github.chidiwilliams.Buzz.yml
+++ /dev/null
@@ -1,90 +0,0 @@
-# Building notes:
-# See https://docs.flathub.org/docs/for-app-authors/submission/
-# This flatpak is build from the snap package.
-# - Get relevant snap package infor - curl -H 'Snap-Device-Series: 16' http://api.snapcraft.io/v2/snaps/info/buzz # | jq
-# - Download snap and generate sha256sum, update yaml entry.
-
-app-id: io.github.chidiwilliams.Buzz
-runtime: org.freedesktop.Platform
-# TODO - Update to 24.08 when snap is updated to core24
-runtime-version: '22.08' # To match `core22` of the snap
-sdk: org.freedesktop.Sdk
-command: run-buzz.sh
-finish-args:
- - --socket=wayland
- - --socket=fallback-x11
- - --socket=pulseaudio
- - --talk-name=org.freedesktop.secrets
- - --device=dri
- # TODO switch 'all' to input when it is widely available
- #- --device=input
- - --device=all
- - --share=network
- - --share=ipc
- - --filesystem=xdg-documents
- # Environment variables
- - --env=LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/app/lib/python3.10/site-packages/nvidia/cudnn/lib:/app/lib/python3.10/site-packages/PyQt6:/app/lib/python3.10/site-packages/PyQt6/Qt6/lib:/app/usr/lib/x86_64-linux-gnu/lapack:/app/usr/lib/x86_64-linux-gnu/blas:/app/usr/lib/x86_64-linux-gnu/pulseaudio:/app/usr/lib/x86_64-linux-gnu:/app/lib/x86_64-linux-gnu/
- - --env=PYTHONPATH=$PYTHONPATH:/app/lib/python3.10/site-packages:/app/lib/python3.10/site-packages/PyQt6:/app/lib/python3.10/site-packages/PyQt6/Qt6/lib
-
-modules:
- - name: unsquashfs
- buildsystem: simple
- build-commands:
- - XZ_SUPPORT=1 make -C squashfs-tools -j ${FLATPAK_BUILDER_N_JOBS} unsquashfs
- - install -Dpm755 -t "${FLATPAK_DEST}/bin" squashfs-tools/unsquashfs
- sources:
- - type: git
- url: https://github.com/plougher/squashfs-tools.git
- tag: 4.6.1
- commit: d8cb82d9840330f9344ec37b992595b5d7b44184
-
- - name: snap
- buildsystem: simple
- build-commands:
- - unsquashfs -dest buzz -quiet -no-progress buzz.snap
- - cp -rT buzz ${FLATPAK_DEST} && rm -rf buzz
- sources:
- - type: file
- dest-filename: buzz.snap
- # Stable 1.2.0
- url: https://api.snapcraft.io/api/v1/snaps/download/RSpCVxCNDwoTXHPXhlYQnziD0jQhVnKA_362.snap
- sha256: fbc045426c867b1d7ee01178d4f53d785c161709e2a9db6854cefec29aa510d7
- # Edge
- #url: https://api.snapcraft.io/api/v1/snaps/download/RSpCVxCNDwoTXHPXhlYQnziD0jQhVnKA_402.snap
- #sha256: 0acecacf8fa476bf6d7afcd98b7b557829b70cfa8b1d57e6ff5248737b63ab60
-
- # Borrowed from https://github.com/flathub/org.audacityteam.Audacity/blob/master/org.audacityteam.Audacity.yaml
- - name: portaudio
- buildsystem: cmake-ninja
- config-opts:
- - -DCMAKE_BUILD_TYPE=RelWithDebInfo
- sources:
- - type: archive
- url: https://github.com/PortAudio/portaudio/archive/refs/tags/v19.7.0.tar.gz
- sha256: 5af29ba58bbdbb7bbcefaaecc77ec8fc413f0db6f4c4e286c40c3e1b83174fa0
-
- # Borrowed from https://github.com/flathub/org.freedownloadmanager.Manager/pull/20/files
- - name: kerberos
- subdir: src
- sources:
- - type: archive
- url: https://kerberos.org/dist/krb5/1.21/krb5-1.21.tar.gz
- sha256: 69f8aaff85484832df67a4bbacd99b9259bd95aab8c651fbbe65cdc9620ea93b
-
- - name: Buzz
- buildsystem: simple
- build-commands:
- - install -Dm755 flatpak/run-buzz.sh ${FLATPAK_DEST}/bin/run-buzz.sh
-
- - install -Dm644 share/icons/${FLATPAK_ID}.svg ${FLATPAK_DEST}/share/icons/hicolor/scalable/apps/${FLATPAK_ID}.svg
- - install -Dm644 share/applications/${FLATPAK_ID}.desktop ${FLATPAK_DEST}/share/applications/${FLATPAK_ID}.desktop
- - install -Dm644 share/metainfo/${FLATPAK_ID}.metainfo.xml ${FLATPAK_DEST}/share/metainfo/${FLATPAK_ID}.metainfo.xml
-
- - install -Dm644 flatpak/libbsd.so.0 ${FLATPAK_DEST}/lib/x86_64-linux-gnu/libbsd.so.0
- - install -Dm644 flatpak/libmd.so.0 ${FLATPAK_DEST}/lib/x86_64-linux-gnu/libmd.so.0
- - install -Dm644 flatpak/libdb-5.3.so ${FLATPAK_DEST}/lib/x86_64-linux-gnu/libdb-5.3.so
- - install -Dm644 flatpak/libapparmor.so.1 ${FLATPAK_DEST}/lib/x86_64-linux-gnu/libapparmor.so.1
- - install -Dm644 flatpak/libavutil.so.58 ${FLATPAK_DEST}/lib/x86_64-linux-gnu/libavutil.so.58
- sources:
- - type: dir
- path: .
diff --git a/share/icons/io.github.chidiwilliams.Buzz.svg b/share/icons/io.github.chidiwilliams.Buzz.svg
index 79604329..d5b67bc0 100644
--- a/share/icons/io.github.chidiwilliams.Buzz.svg
+++ b/share/icons/io.github.chidiwilliams.Buzz.svg
@@ -1,16 +1,23 @@
\ No newline at end of file
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml b/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
index d65251fd..b94e23bd 100644
--- a/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
+++ b/share/metainfo/io.github.chidiwilliams.Buzz.metainfo.xml
@@ -64,7 +64,7 @@
-
+ https://github.com/chidiwilliams/buzz/releases/tag/v1.3.3
This release introduces Vulkan GPU support for whisper.cpp making it significantly faster even on laptops.
From de1ed90f50eedef512c274b1a0468d76567e712f Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Tue, 18 Nov 2025 18:22:10 +0200
Subject: [PATCH 07/73] Fix for snap (#1286)
---
snap/snapcraft.yaml | 1 +
1 file changed, 1 insertion(+)
diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml
index 8017d213..346ff7e4 100644
--- a/snap/snapcraft.yaml
+++ b/snap/snapcraft.yaml
@@ -89,6 +89,7 @@ parts:
- libgstreamer1.0-0
- libgstreamer-plugins-base1.0-0
- libgstreamer-plugins-good1.0-0
+ - liboss4-salsa2
# Display
- libxkbcommon-x11-0
- libxcb-icccm4
From 5a81c715d1a6dc4ee768fe29b06fa1613a58056b Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Thu, 20 Nov 2025 07:50:56 +0200
Subject: [PATCH 08/73] Adjusting Windows build notes (#1288)
---
CONTRIBUTING.md | 34 ++++++++++-----------------
buzz/locale/it_IT/LC_MESSAGES/buzz.po | 6 +----
2 files changed, 13 insertions(+), 27 deletions(-)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 43df6166..d8a540cf 100755
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -94,16 +94,18 @@ Assumes you have [Git](https://git-scm.com/downloads) and [python](https://www.p
```
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
```
-2. Install the GNU make. `choco install make`
+2. Install the build tools. `choco install make cmake`
3. Install the ffmpeg. `choco install ffmpeg`
-4. Install [MSYS2](https://www.msys2.org/), follow [this guide](https://sajidifti.medium.com/how-to-install-gcc-and-gdb-on-windows-using-msys2-tutorial-0fceb7e66454).
-5. Clone the repository `git clone --recursive https://github.com/chidiwilliams/buzz.git`
-6. Enter repo folder `cd buzz`
-7. Install uv `powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"`
-8. Install the dependencies `uv sync`
-9. `cp -r .\dll_backup\ .\buzz\`
-10. Build Buzz `uv build`
-11. Run Buzz `uv run buzz`
+4. Download [Build Tools for Visual Studio 2022](https://visualstudio.microsoft.com/vs/older-downloads/) and install "Desktop development with C++" workload.
+5. Add location of `nmake` to your PATH environment variable. Usually it is `C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\bin\Hostx64\x86`
+6. Install Vulkan SDK from https://vulkan.lunarg.com/sdk/home
+7. Clone the repository `git clone --recursive https://github.com/chidiwilliams/buzz.git`
+8. Enter repo folder `cd buzz`
+9. Install uv `powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"`
+10. Install the dependencies `uv sync`
+11. Build Whisper.cpp `uv run make buzz/whisper_cpp`
+12. `cp -r .\dll_backup\ .\buzz\`
+13. Run Buzz `uv run buzz`
Note: It should be safe to ignore any "syntax errors" you see during the build. Buzz will work. Also you can ignore any errors for FFmpeg. Buzz tries to load FFmpeg by several different means and some of them throw errors, but FFmpeg should eventually be found and work.
@@ -119,16 +121,4 @@ uv add --index https://pypi.ngc.nvidia.com nvidia-cublas-cu12==12.8.3.14 nvidia-
To use Faster Whisper on GPU, install the following libraries:
* [cuBLAS](https://developer.nvidia.com/cublas)
-* [cuDNN](https://developer.nvidia.com/cudnn)
-
-If you run into issues with FFmpeg, ensure ffmpeg dependencies are installed
-```
-pip3 uninstall ffmpeg ffmpeg-python
-pip3 install ffmpeg
-pip3 install ffmpeg-python
-```
-
-For Whisper.cpp you will need to install Vulkan SDK.
-Follow the instructions here https://vulkan.lunarg.com/doc/sdk/latest/windows/getting_started.html
-
-Run Buzz `python -m buzz`
\ No newline at end of file
+* [cuDNN](https://developer.nvidia.com/cudnn)
\ No newline at end of file
diff --git a/buzz/locale/it_IT/LC_MESSAGES/buzz.po b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
index b1206756..f215fb66 100644
--- a/buzz/locale/it_IT/LC_MESSAGES/buzz.po
+++ b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
@@ -106,10 +106,6 @@ msgstr "Lettone"
msgid "Polish"
msgstr "Polacco"
-#: buzz/widgets/preferences_dialog/general_preferences_widget.py:45
-msgid "Portuguese"
-msgstr "Portoghese"
-
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:46
#: buzz/transcriber/transcriber.py:59
msgid "Ukrainian"
@@ -391,7 +387,7 @@ msgid ""
"Enter instructions for AI on how to translate, for example 'Please translate "
"each text sent to you from English to Spanish.'"
msgstr ""
-Inserisci le istruzioni per l'IA su come tradurre, ad esempio 'Per favore, traduci "
+"Inserisci le istruzioni per l'IA su come tradurre, ad esempio 'Per favore, traduci "
"ogni testo che ti viene inviato dall'inglese allo spagnolo.'"
#: buzz/widgets/transcriber/advanced_settings_dialog.py:92
From f3765a586fe004c7c6bd906ca16f3e39e73984c8 Mon Sep 17 00:00:00 2001
From: David Olowomeye <100958002+greatdaveo@users.noreply.github.com>
Date: Mon, 24 Nov 2025 07:20:12 +0000
Subject: [PATCH 09/73] Implemented resume functionality for downloading models
#1287 (#1289)
---
buzz/model_loader.py | 215 +++++++++++++++++++++++++++++++++++--------
1 file changed, 175 insertions(+), 40 deletions(-)
diff --git a/buzz/model_loader.py b/buzz/model_loader.py
index 790dbbdf..ce12ba42 100644
--- a/buzz/model_loader.py
+++ b/buzz/model_loader.py
@@ -30,7 +30,6 @@ os.makedirs(model_root_dir, exist_ok=True)
logging.debug("Model root directory: %s", model_root_dir)
-
class WhisperModelSize(str, enum.Enum):
TINY = "tiny"
TINYEN = "tiny.en"
@@ -60,6 +59,25 @@ class WhisperModelSize(str, enum.Enum):
def __str__(self):
return self.value.capitalize()
+# Approximate expected file sizes for Whisper models
+WHISPER_MODEL_SIZES = {
+ WhisperModelSize.TINY: 75 * 1024 * 1024,
+ WhisperModelSize.TINYEN: 75 * 1024 * 1024,
+ WhisperModelSize.BASE: 150 * 1024 * 1024,
+ WhisperModelSize.BASEEN: 150 * 1024 * 1024,
+ WhisperModelSize.SMALL: 500 * 1024 * 1024,
+ WhisperModelSize.SMALLEN: 500 * 1024 * 1024,
+ WhisperModelSize.MEDIUM: 1500 * 1024 * 1024,
+ WhisperModelSize.MEDIUMEN: 1500 * 1024 * 1024,
+ WhisperModelSize.LARGE: 3100 * 1024 * 1024,
+ WhisperModelSize.LARGEV2: 3100 * 1024 * 1024,
+ WhisperModelSize.LARGEV3: 3100 * 1024 * 1024,
+ WhisperModelSize.LARGEV3TURBO: 3100 * 1024 * 1024,
+}
+
+def get_expected_whisper_model_size(size: WhisperModelSize) -> Optional[int]:
+ """Get expected file size for a Whisper model without network request."""
+ return WHISPER_MODEL_SIZES.get(size, None)
class ModelType(enum.Enum):
WHISPER = "Whisper"
@@ -200,7 +218,21 @@ class TranscriptionModel:
file_path = get_whisper_file_path(size=self.whisper_model_size)
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return None
- return file_path
+
+ file_size = os.path.getsize(file_path)
+
+ expected_size = get_expected_whisper_model_size(self.whisper_model_size)
+
+ if expected_size is not None:
+ if file_size < expected_size * 0.95: # Allow 5% tolerance for file system differences
+ return None
+ return file_path
+ else:
+ # For unknown model size
+ if file_size < 50 * 1024 * 1024:
+ return None
+
+ return file_path
if self.model_type == ModelType.FASTER_WHISPER:
try:
@@ -244,7 +276,7 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
model_filename = f"ggml-{size.to_whisper_cpp_model_size()}.bin"
try:
- model_path = huggingface_hub.snapshot_download(
+ model_path = huggingface_hub.snapshot_download(
repo_id=repo_id,
allow_patterns=[model_filename],
local_files_only=True,
@@ -271,7 +303,8 @@ class HuggingfaceDownloadMonitor:
def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size: int):
self.model_root = model_root
self.progress = progress
- self.total_file_size = round(total_file_size * 1.1) # To keep dialog open even if it reports 100%
+ # To keep dialog open even if it reports 100%
+ self.total_file_size = round(total_file_size * 1.1)
self.incomplete_download_root = None
self.stop_event = threading.Event()
self.monitor_thread = None
@@ -279,8 +312,10 @@ class HuggingfaceDownloadMonitor:
def set_download_roots(self):
normalized_model_root = os.path.normpath(self.model_root)
- two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
- self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))
+ two_dirs_up = os.path.normpath(
+ os.path.join(normalized_model_root, "..", ".."))
+ self.incomplete_download_root = os.path.normpath(
+ os.path.join(two_dirs_up, "blobs"))
def clean_tmp_files(self):
for filename in os.listdir(model_root_dir):
@@ -292,12 +327,14 @@ class HuggingfaceDownloadMonitor:
if model_root_dir is not None:
for filename in os.listdir(model_root_dir):
if filename.startswith("tmp"):
- file_size = os.path.getsize(os.path.join(model_root_dir, filename))
+ file_size = os.path.getsize(
+ os.path.join(model_root_dir, filename))
self.progress.emit((file_size, self.total_file_size))
for filename in os.listdir(self.incomplete_download_root):
if filename.endswith(".incomplete"):
- file_size = os.path.getsize(os.path.join(self.incomplete_download_root, filename))
+ file_size = os.path.getsize(os.path.join(
+ self.incomplete_download_root, filename))
self.progress.emit((file_size, self.total_file_size))
time.sleep(2)
@@ -332,7 +369,8 @@ def download_from_huggingface(
try:
model_root = huggingface_hub.snapshot_download(
repo_id,
- allow_patterns=allow_patterns[num_large_files:], # all, but largest
+ # all, but largest
+ allow_patterns=allow_patterns[num_large_files:],
cache_dir=model_root_dir,
etag_timeout=60
)
@@ -354,7 +392,8 @@ def download_from_huggingface(
except requests.exceptions.RequestException as e:
continue
- model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, largest_file_size)
+ model_download_monitor = HuggingfaceDownloadMonitor(
+ model_root, progress, largest_file_size)
model_download_monitor.start_monitoring()
try:
@@ -367,9 +406,7 @@ def download_from_huggingface(
except Exception as exc:
logging.exception(exc)
model_download_monitor.stop_monitoring()
- # Cleanup to prevent incomplete downloads errors
- if os.path.exists(model_root):
- shutil.rmtree(model_root)
+
return ""
model_download_monitor.stop_monitoring()
@@ -429,19 +466,22 @@ class ModelDownloader(QRunnable):
def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
super().__init__()
- self.is_coreml_supported = platform.system() == "Darwin" and platform.machine() == "arm64"
+ self.is_coreml_supported = platform.system(
+ ) == "Darwin" and platform.machine() == "arm64"
self.signals = self.Signals()
self.model = model
self.stopped = False
self.custom_model_url = custom_model_url
def run(self) -> None:
- logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)
+ logging.debug("Downloading model: %s, %s", self.model,
+ self.model.hugging_face_model_id)
if self.model.model_type == ModelType.WHISPER_CPP:
if self.custom_model_url:
url = self.custom_model_url
- file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
+ file_path = get_whisper_cpp_file_path(
+ size=self.model.whisper_model_size)
return self.download_model_to_path(url=url, file_path=file_path)
repo_id = WHISPER_CPP_REPO_ID
@@ -458,9 +498,9 @@ class ModelDownloader(QRunnable):
num_large_files = 1
if self.is_coreml_supported:
whisper_cpp_model_files = [
- f"ggml-{model_name}.bin",
- f"ggml-{model_name}-encoder.mlmodelc.zip",
- "README.md"
+ f"ggml-{model_name}.bin",
+ f"ggml-{model_name}-encoder.mlmodelc.zip",
+ "README.md"
]
num_large_files = 2
@@ -476,12 +516,14 @@ class ModelDownloader(QRunnable):
os.path.join(model_path, f"ggml-{model_name}-encoder.mlmodelc.zip"), 'r') as zip_ref:
zip_ref.extractall(model_path)
- self.signals.finished.emit(os.path.join(model_path, f"ggml-{model_name}.bin"))
+ self.signals.finished.emit(os.path.join(
+ model_path, f"ggml-{model_name}.bin"))
return
if self.model.model_type == ModelType.WHISPER:
url = whisper._MODELS[self.model.whisper_model_size.value]
- file_path = get_whisper_file_path(size=self.model.whisper_model_size)
+ file_path = get_whisper_file_path(
+ size=self.model.whisper_model_size)
expected_sha256 = url.split("/")[-2]
return self.download_model_to_path(
url=url, file_path=file_path, expected_sha256=expected_sha256
@@ -526,16 +568,18 @@ class ModelDownloader(QRunnable):
downloaded = self.download_model(url, file_path, expected_sha256)
if downloaded:
self.signals.finished.emit(file_path)
- except requests.RequestException:
+ except requests.RequestException as e:
self.signals.error.emit(_("A connection error occurred"))
- if os.path.exists(file_path):
- os.remove(file_path)
+ if not self.stopped and "timeout" not in str(e).lower():
+ if os.path.exists(file_path):
+ os.remove(file_path)
logging.exception("")
except Exception as exc:
self.signals.error.emit(str(exc))
- if os.path.exists(file_path):
- os.remove(file_path)
- logging.exception(exc)
+ if not self.stopped:
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ logging.exception(exc)
def download_model(
self, url: str, file_path: str, expected_sha256: Optional[str]
@@ -547,27 +591,118 @@ class ModelDownloader(QRunnable):
if os.path.exists(file_path) and not os.path.isfile(file_path):
raise RuntimeError(f"{file_path} exists and is not a regular file")
+ resume_from = 0
+ file_mode = "wb"
+
if os.path.isfile(file_path):
- if expected_sha256 is None:
- return True
+ file_size = os.path.getsize(file_path)
- model_bytes = open(file_path, "rb").read()
- model_sha256 = hashlib.sha256(model_bytes).hexdigest()
- if model_sha256 == expected_sha256:
- return True
+ if expected_sha256 is not None:
+ # Get the expected file size from URL
+ try:
+ head_response = requests.head(url, timeout=5, allow_redirects=True)
+ expected_size = int(head_response.headers.get("Content-Length", 0))
+
+ if expected_size > 0:
+ if file_size < expected_size:
+ resume_from = file_size
+ file_mode = "ab"
+ logging.debug(
+ f"File incomplete ({file_size}/{expected_size} bytes), resuming from byte {resume_from}"
+ )
+ elif file_size == expected_size:
+ # This means file size matches - verify SHA256 to confirm it is complete
+ try:
+ with open(file_path, "rb") as f:
+ model_bytes = f.read()
+ model_sha256 = hashlib.sha256(model_bytes).hexdigest()
+ if model_sha256 == expected_sha256:
+ logging.debug("Model already downloaded and verified")
+ return True
+ else:
+ warnings.warn(
+ f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
+ )
+ # File exists but it is wrong, delete it
+ os.remove(file_path)
+ except Exception as e:
+ logging.warning(f"Error checking existing file: {e}")
+ os.remove(file_path)
+ else:
+ # File is larger than expected - corrupted, delete it
+ warnings.warn(f"File size ({file_size}) exceeds expected size ({expected_size}), re-downloading")
+ os.remove(file_path)
+ else:
+ # Can't get expected size - use threshold approach
+ if file_size < 10 * 1024 * 1024:
+ resume_from = file_size
+ file_mode = "ab" # Append mode to resume
+ logging.debug(f"Resuming download from byte {resume_from}")
+ else:
+ # Large file - verify SHA256
+ try:
+ with open(file_path, "rb") as f:
+ model_bytes = f.read()
+ model_sha256 = hashlib.sha256(model_bytes).hexdigest()
+ if model_sha256 == expected_sha256:
+ logging.debug("Model already downloaded and verified")
+ return True
+ else:
+ warnings.warn("SHA256 mismatch, re-downloading")
+ os.remove(file_path)
+ except Exception as e:
+ logging.warning(f"Error verifying file: {e}")
+ os.remove(file_path)
+
+ except Exception as e:
+ # Can't get expected size - use threshold
+ logging.debug(f"Could not get expected file size: {e}, using threshold")
+ if file_size < 10 * 1024 * 1024:
+ resume_from = file_size
+ file_mode = "ab"
+ logging.debug(f"Resuming from byte {resume_from}")
else:
- warnings.warn(
- f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
- )
-
+ # No SHA256 to verify - just check file size
+ if file_size > 0:
+ resume_from = file_size
+ file_mode = "ab"
+ logging.debug(f"Resuming download from byte {resume_from}")
+
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
+ headers = {}
+ if resume_from > 0:
+ headers["Range"] = f"bytes={resume_from}-"
+
with requests.get(url, stream=True, timeout=15) as source, open(
- file_path, "wb"
+ file_path, file_mode
) as output:
source.raise_for_status()
- total_size = float(source.headers.get("Content-Length", 0))
- current = 0.0
+
+ if resume_from > 0:
+ if source.status_code == 206:
+ logging.debug(
+ f"Server supports resume, continuing from byte {resume_from}")
+ total_size = int(source.headers.get(
+ "Content-Range", "").split("/")[-1])
+ current = resume_from
+ self.signals.progress.emit((current, total_size))
+ elif source.status_code == 200:
+ logging.debug(
+ "Server doesn't support Range requests, starting from beginning")
+ # Truncate file and start over
+ output.close()
+ output = open(file_path, "wb")
+ total_size = float(source.headers.get("Content-Length", 0))
+ current = 0.0
+ resume_from = 0
+ else:
+ source.raise_for_status()
+
+ else:
+ total_size = float(source.headers.get("Content-Length", 0))
+ current = 0.0
+
self.signals.progress.emit((current, total_size))
for chunk in source.iter_content(chunk_size=8192):
if self.stopped:
From 252db3c3edaddbd487b0de9d77efa0ad30111356 Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Mon, 24 Nov 2025 21:59:21 +0200
Subject: [PATCH 10/73] Adding option to delete saved models and files on
uninstall (#1291)
---
installer.iss | 25 +++++++++++++++----------
1 file changed, 15 insertions(+), 10 deletions(-)
diff --git a/installer.iss b/installer.iss
index 69fa9b39..85b690d0 100644
--- a/installer.iss
+++ b/installer.iss
@@ -51,16 +51,6 @@ Filename: "{app}\{#AppExeName}"; Description: "{cm:LaunchProgram,{#StringChange(
Root: HKCU; Subkey: "{#AppRegKey}"
[Code]
-procedure CurUninstallStepChanged(CurUninstallStep: TUninstallStep);
-begin
- if CurUninstallStep = usPostUninstall then
- begin
- if RegKeyExists(HKEY_CURRENT_USER, '{#AppRegKey}') then
- if MsgBox('Do you want to delete Buzz settings?', mbConfirmation, MB_YESNO) = IDYES
- then
- RegDeleteKeyIncludingSubkeys(HKEY_CURRENT_USER, '{#AppRegKey}');
- end;
-end;
procedure DeleteFileOrFolder(FilePath: string);
begin
if FileExists(FilePath) then
@@ -73,6 +63,21 @@ begin
end;
end;
+procedure CurUninstallStepChanged(CurUninstallStep: TUninstallStep);
+begin
+ if CurUninstallStep = usPostUninstall then
+ begin
+ if RegKeyExists(HKEY_CURRENT_USER, '{#AppRegKey}') then
+ if MsgBox('Do you want to delete Buzz settings and saved files?', mbConfirmation, MB_YESNO) = IDYES
+ then
+ begin
+ RegDeleteKeyIncludingSubkeys(HKEY_CURRENT_USER, '{#AppRegKey}');
+ // Remove model and cache directories
+ DeleteFileOrFolder(ExpandConstant('{localappdata}\Buzz'));
+ end;
+ end;
+end;
+
procedure CurStepChanged(CurStep: TSetupStep);
begin
if CurStep = ssInstall then
From cabbd487f94b6f2516df58d7d03ad9e0afaa218f Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Fri, 28 Nov 2025 21:30:36 +0200
Subject: [PATCH 11/73] Improvements (#1296)
---
Makefile | 2 +-
buzz/__version__.py | 2 +-
buzz/locale/ca_ES/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/da_DK/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/de_DE/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/en_US/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/es_ES/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/it_IT/LC_MESSAGES/buzz.po | 54 +++++++++++++++++----------
buzz/locale/ja_JP/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/lv_LV/LC_MESSAGES/buzz.po | 26 +++++++------
buzz/locale/nl/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/pl_PL/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/pt_BR/LC_MESSAGES/buzz.po | 32 +++++++++-------
buzz/locale/uk_UA/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/zh_CN/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/locale/zh_TW/LC_MESSAGES/buzz.po | 24 +++++++-----
buzz/transcriber/whisper_cpp.py | 9 +++--
buzz/widgets/about_dialog.py | 10 +++++
docs/docs/faq.md | 2 +
pyproject.toml | 2 +-
snap/snapcraft.yaml | 2 +-
whisper.cpp | 2 +-
22 files changed, 244 insertions(+), 163 deletions(-)
diff --git a/Makefile b/Makefile
index 9b4050ef..af2aa9a1 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-version := 1.3.3
+version := 1.3.4
mac_app_path := ./dist/Buzz.app
mac_zip_path := ./dist/Buzz-${version}-mac.zip
diff --git a/buzz/__version__.py b/buzz/__version__.py
index e371c8ac..4a16f216 100644
--- a/buzz/__version__.py
+++ b/buzz/__version__.py
@@ -1 +1 @@
-VERSION = "1.3.3"
+VERSION = "1.3.4"
diff --git a/buzz/locale/ca_ES/LC_MESSAGES/buzz.po b/buzz/locale/ca_ES/LC_MESSAGES/buzz.po
index 1450d9c0..e88359f6 100644
--- a/buzz/locale/ca_ES/LC_MESSAGES/buzz.po
+++ b/buzz/locale/ca_ES/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: buzz\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2025-10-17 07:59+0200\n"
"Last-Translator: Éric Duarte \n"
"Language-Team: Catalan \n"
@@ -308,8 +308,8 @@ msgid "Download failed"
msgstr "Descàrrega fallida"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Error"
@@ -513,11 +513,15 @@ msgstr ""
"Comproveu els vostres dispositius d'àudio o els registres de l'aplicació per "
"a més informació."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Comprova si hi ha actualitzacions"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Estàs al dia!"
@@ -764,14 +768,14 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "No s'ha pogut desar la clau OpenAI API a l'anell de claus"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
"El servidor Whisper no s'ha pogut iniciar. Consulteu els registres per "
"obtenir més informació."
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1145,15 +1149,15 @@ msgstr "Sundanès"
msgid "Cantonese"
msgstr "Cantonès"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "S'ha produït un error de connexió"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr "Començant Whisper.cpp..."
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Cancel·la la transcripció"
diff --git a/buzz/locale/da_DK/LC_MESSAGES/buzz.po b/buzz/locale/da_DK/LC_MESSAGES/buzz.po
index 7328ba15..7d356c67 100644
--- a/buzz/locale/da_DK/LC_MESSAGES/buzz.po
+++ b/buzz/locale/da_DK/LC_MESSAGES/buzz.po
@@ -2,7 +2,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: \n"
"Last-Translator: Ole Guldberg2 \n"
"Language-Team: \n"
@@ -307,8 +307,8 @@ msgid "Download failed"
msgstr "Download mislykkedes"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Fejl"
@@ -511,11 +511,15 @@ msgstr ""
"Tjek venligst dine audioenheder eller tjek applikationens logs for "
"mereinformation."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Tjek for opdateringer"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Du er opdateret!"
@@ -760,12 +764,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "Kan ikke gemme OpenAI API-nøgle i nøgleringen"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1137,15 +1141,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Der er opstået en forbindelsesfejl"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Afbryd transkription"
diff --git a/buzz/locale/de_DE/LC_MESSAGES/buzz.po b/buzz/locale/de_DE/LC_MESSAGES/buzz.po
index de802203..14ebc504 100644
--- a/buzz/locale/de_DE/LC_MESSAGES/buzz.po
+++ b/buzz/locale/de_DE/LC_MESSAGES/buzz.po
@@ -6,7 +6,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2025-03-05 14:41+0100\n"
"Last-Translator: \n"
"Language-Team: \n"
@@ -307,8 +307,8 @@ msgid "Download failed"
msgstr "Der Download ist fehlgeschlagen"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Fehler"
@@ -511,11 +511,15 @@ msgstr ""
"Bitte überprüfen Sie Ihre Audiogeräte oder prüfen Sie die "
"Anwendungsprotokolle für weitere Informationen."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Nach Updates suchen"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Sie sind auf dem Laufenden!"
@@ -761,12 +765,12 @@ msgstr ""
"Der OpenAI-API-Schlüssel kann nicht im Schlüsselbund gespeichert werden"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1138,15 +1142,15 @@ msgstr "Sundanesisch"
msgid "Cantonese"
msgstr "Kantonesisch"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Ein Verbindungsfehler ist aufgetreten"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Transkription abbrechen"
diff --git a/buzz/locale/en_US/LC_MESSAGES/buzz.po b/buzz/locale/en_US/LC_MESSAGES/buzz.po
index d7fb3dc7..87f47cea 100644
--- a/buzz/locale/en_US/LC_MESSAGES/buzz.po
+++ b/buzz/locale/en_US/LC_MESSAGES/buzz.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language-Team: LANGUAGE \n"
@@ -299,8 +299,8 @@ msgid "Download failed"
msgstr ""
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr ""
@@ -499,11 +499,15 @@ msgid ""
"information."
msgstr ""
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr ""
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr ""
@@ -742,12 +746,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1118,15 +1122,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
msgid "Starting transcription..."
msgstr ""
diff --git a/buzz/locale/es_ES/LC_MESSAGES/buzz.po b/buzz/locale/es_ES/LC_MESSAGES/buzz.po
index f7e2d9e3..133209e1 100644
--- a/buzz/locale/es_ES/LC_MESSAGES/buzz.po
+++ b/buzz/locale/es_ES/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2025-09-08 12:43+0200\n"
"Last-Translator: Éric Duarte \n"
"Language-Team: \n"
@@ -314,8 +314,8 @@ msgid "Download failed"
msgstr "Descarga fallida"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Error"
@@ -546,12 +546,16 @@ msgstr ""
"aplicación para obtener más información."
# automatic translation
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Buscar actualizaciones"
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
# automatic translation
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "¡Estás al día!"
@@ -810,14 +814,14 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "No se puede guardar la clave de la API de OpenAI en el llavero"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
"El servidor Whisper no se pudo iniciar. Consulta los registros para obtener "
"más detalles."
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1192,16 +1196,16 @@ msgstr "Sundanés"
msgid "Cantonese"
msgstr "Cantonés"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Se ha producido un error de conexión"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr "Iniciando Whisper.cpp..."
# automatic translation
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Cancelar transcripción"
diff --git a/buzz/locale/it_IT/LC_MESSAGES/buzz.po b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
index f215fb66..127a3b0c 100644
--- a/buzz/locale/it_IT/LC_MESSAGES/buzz.po
+++ b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
@@ -6,7 +6,7 @@ msgid ""
msgstr ""
"Project-Id-Version: buzz\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2025-11-09 20:22+0200\n"
"Language-Team: (Italiano) Albano Battistella \n"
"Language: it_IT\n"
@@ -106,6 +106,11 @@ msgstr "Lettone"
msgid "Polish"
msgstr "Polacco"
+#: buzz/widgets/preferences_dialog/general_preferences_widget.py:45
+#, fuzzy
+msgid "Portuguese (Brazil)"
+msgstr "Portoghese"
+
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:46
#: buzz/transcriber/transcriber.py:59
msgid "Ukrainian"
@@ -171,7 +176,9 @@ msgstr "Utilizza solo la CPU e disattiva l'accelerazione GPU"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:186
msgid "Set this if larger models do not fit your GPU memory and Buzz crashes"
-msgstr "Imposta questa opzione se i modelli più grandi non si adattano alla memoria della tua GPU e Buzz si blocca"
+msgstr ""
+"Imposta questa opzione se i modelli più grandi non si adattano alla memoria "
+"della tua GPU e Buzz si blocca"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:188
msgid "Disable GPU"
@@ -301,8 +308,8 @@ msgid "Download failed"
msgstr "Download non riuscito"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Errore"
@@ -387,8 +394,8 @@ msgid ""
"Enter instructions for AI on how to translate, for example 'Please translate "
"each text sent to you from English to Spanish.'"
msgstr ""
-"Inserisci le istruzioni per l'IA su come tradurre, ad esempio 'Per favore, traduci "
-"ogni testo che ti viene inviato dall'inglese allo spagnolo.'"
+"Inserisci le istruzioni per l'IA su come tradurre, ad esempio 'Per favore, "
+"traduci ogni testo che ti viene inviato dall'inglese allo spagnolo.'"
#: buzz/widgets/transcriber/advanced_settings_dialog.py:92
msgid "Instructions for AI:"
@@ -507,11 +514,15 @@ msgstr ""
"Controlla i tuoi dispositivi audio o i registri dell'applicazione per "
"maggiori informazioni."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Controlla gli aggiornamenti"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Il programma è aggiornato!"
@@ -595,7 +606,8 @@ msgstr "Ciclo di segmento"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
msgid "Enable/disable looping when clicking on transcript segments"
-msgstr "Abilita/disabilita il loop quando si fa clic sui segmenti della trascrizione"
+msgstr ""
+"Abilita/disabilita il loop quando si fa clic sui segmenti della trascrizione"
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
msgid "Follow Audio"
@@ -606,8 +618,9 @@ msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-"Abilita/disabilita la lettura della posizione audio corrente nella trascrizione. Quando "
-"abilitato, scorre automaticamente fino al testo corrente."
+"Abilita/disabilita la lettura della posizione audio corrente nella "
+"trascrizione. Quando abilitato, scorre automaticamente fino al testo "
+"corrente."
#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
msgid "Scroll to Current"
@@ -758,20 +771,21 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "Impossibile salvare la chiave API OpenAI nel portachiavi"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
-msgstr "Impossibile avviare il server Whisper. Controllare i log per i dettagli."
+msgstr ""
+"Impossibile avviare il server Whisper. Controllare i log per i dettagli."
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
"variable."
msgstr ""
-"Impossibile avviare il server Whisper a causa di memoria insufficiente. Riprovare "
-"con un modello più piccolo. Per forzare la modalità CPU, utilizzare la variabile d'ambiente "
-"BUZZ_FORCE_CPU=TRUE"
+"Impossibile avviare il server Whisper a causa di memoria insufficiente. "
+"Riprovare con un modello più piccolo. Per forzare la modalità CPU, "
+"utilizzare la variabile d'ambiente BUZZ_FORCE_CPU=TRUE"
#: buzz/transcriber/transcriber.py:24
msgid "Translate to English"
@@ -1137,15 +1151,15 @@ msgstr "Sundanese"
msgid "Cantonese"
msgstr "Cantonese"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Si è verificato un errore di connessione"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr "Avvio di Whisper.cpp..."
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
msgid "Starting transcription..."
msgstr "Inizio trascrizione..."
diff --git a/buzz/locale/ja_JP/LC_MESSAGES/buzz.po b/buzz/locale/ja_JP/LC_MESSAGES/buzz.po
index efbab4d7..b5ec2b11 100644
--- a/buzz/locale/ja_JP/LC_MESSAGES/buzz.po
+++ b/buzz/locale/ja_JP/LC_MESSAGES/buzz.po
@@ -2,7 +2,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: \n"
"Last-Translator: nunawa <71294849+nunawa@users.noreply.github.com>\n"
"Language-Team: \n"
@@ -303,8 +303,8 @@ msgid "Download failed"
msgstr "ダウンロード失敗"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "エラー"
@@ -507,11 +507,15 @@ msgstr ""
"オーディオデバイスを確認するか、詳細をアプリケーションのログで確認してくださ"
"い。"
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "アップデートを確認する"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "最新の状態です!"
@@ -755,12 +759,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "OpenAI API キーをkeyringに保存できません"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1132,15 +1136,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "接続エラーが発生しました"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "文字起こしをキャンセルする"
diff --git a/buzz/locale/lv_LV/LC_MESSAGES/buzz.po b/buzz/locale/lv_LV/LC_MESSAGES/buzz.po
index ae700ba1..18f799f5 100644
--- a/buzz/locale/lv_LV/LC_MESSAGES/buzz.po
+++ b/buzz/locale/lv_LV/LC_MESSAGES/buzz.po
@@ -8,8 +8,8 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
-"PO-Revision-Date: 2025-10-12 19:11+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"PO-Revision-Date: 2025-11-28 16:50+0200\n"
"Last-Translator: \n"
"Language-Team: \n"
"Language: lv_LV\n"
@@ -311,8 +311,8 @@ msgid "Download failed"
msgstr "Lejupielāde neizdevās"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Kļūda"
@@ -517,11 +517,15 @@ msgstr ""
"Lūdzu pārbaudiet savas audio ierīces vai pārbaudiet lietotnes ziņojumu "
"žurnālus, lai iegūtu papildu informāciju."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Pārbaudīt atjauninājumus"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr "Parādīt sistēmas žurnālu"
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Jums ir jaunākā versija!"
@@ -766,14 +770,14 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "Neizdevās saglabāt OpenAI API atslēgu atslēgu saišķī"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
"Whisper serverim neizdevās ieslēgties. Lūdzu pārbaudiet lietotnes žurnāla "
"ierakstus."
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1147,15 +1151,15 @@ msgstr "Sundāņu"
msgid "Cantonese"
msgstr "Kantonas"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Notika savienojuma kļūda"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr "Palaiž Whisper.cpp..."
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
msgid "Starting transcription..."
msgstr "Sāk atpazīšanu..."
diff --git a/buzz/locale/nl/LC_MESSAGES/buzz.po b/buzz/locale/nl/LC_MESSAGES/buzz.po
index b311c175..2e21acc2 100644
--- a/buzz/locale/nl/LC_MESSAGES/buzz.po
+++ b/buzz/locale/nl/LC_MESSAGES/buzz.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2025-03-20 18:30+0100\n"
"Last-Translator: Heimen Stoffels \n"
"Language-Team: none\n"
@@ -309,8 +309,8 @@ msgid "Download failed"
msgstr "Het downloaden is mislukt"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Foutmelding"
@@ -511,11 +511,15 @@ msgid ""
"information."
msgstr "Controleer uw geluidsapparatuur of het programmalogboek."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Controleren op updates"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "De software is actueel!"
@@ -759,12 +763,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "De OpenAI-api-sleutel kan niet worden bewaard in de sleutelbos"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1136,15 +1140,15 @@ msgstr "Soedanees"
msgid "Cantonese"
msgstr "Kantonees"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Er is een verbindingsfout opgetreden"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Transcriptie wissen"
diff --git a/buzz/locale/pl_PL/LC_MESSAGES/buzz.po b/buzz/locale/pl_PL/LC_MESSAGES/buzz.po
index 09c61f94..2d452294 100644
--- a/buzz/locale/pl_PL/LC_MESSAGES/buzz.po
+++ b/buzz/locale/pl_PL/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2024-03-17 20:50+0200\n"
"Last-Translator: \n"
"Language-Team: \n"
@@ -310,8 +310,8 @@ msgid "Download failed"
msgstr "Pobrany"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Błąd"
@@ -519,11 +519,15 @@ msgstr ""
"Sprawdź urządzenia audio lub przejrzyj logi aplikacji, by uzyskać więcej "
"informacji."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Sprawdź aktualizacje"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Posiadasz najnowszą wersję!"
@@ -769,12 +773,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1147,15 +1151,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Anuluj transkrypcję"
diff --git a/buzz/locale/pt_BR/LC_MESSAGES/buzz.po b/buzz/locale/pt_BR/LC_MESSAGES/buzz.po
index 6e002ac8..25165acd 100644
--- a/buzz/locale/pt_BR/LC_MESSAGES/buzz.po
+++ b/buzz/locale/pt_BR/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Buzz\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2025-11-01 17:43-0300\n"
"Last-Translator: Paulo Schopf \n"
"Language-Team: none\n"
@@ -176,7 +176,8 @@ msgstr "Usar somente a CPU e desabilitar aceleração por GPU"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:186
msgid "Set this if larger models do not fit your GPU memory and Buzz crashes"
-msgstr "Marque isso se modelos maiores não couberem na memória da GPU e o Buzz travar"
+msgstr ""
+"Marque isso se modelos maiores não couberem na memória da GPU e o Buzz travar"
#: buzz/widgets/preferences_dialog/general_preferences_widget.py:188
msgid "Disable GPU"
@@ -306,8 +307,8 @@ msgid "Download failed"
msgstr "Falha ao baixar"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Erro"
@@ -390,8 +391,8 @@ msgid ""
"Enter instructions for AI on how to translate, for example 'Please translate "
"each text sent to you from English to Spanish.'"
msgstr ""
-"Instrua a IA sobre como traduzir, por exemplo: \"Por favor, "
-"traduza cada texto enviado a você do Inglês para o Português\"."
+"Instrua a IA sobre como traduzir, por exemplo: \"Por favor, traduza cada "
+"texto enviado a você do Inglês para o Português\"."
#: buzz/widgets/transcriber/advanced_settings_dialog.py:92
msgid "Instructions for AI:"
@@ -510,11 +511,15 @@ msgstr ""
"Verifique seus dispositivos de áudio ou os logs do aplicativo para mais "
"informações."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Verificar atualizações"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "Você está atualizado!"
@@ -761,12 +766,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "Não foi possível salvar a chave da API OpenAI no cofre de chaves"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr "Falha ao iniciar o servidor Whisper. Verifique os logs."
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1140,15 +1145,15 @@ msgstr "Sundanês"
msgid "Cantonese"
msgstr "Cantonês"
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Ocorreu um erro de conexão"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr "Iniciando Whisper.cpp..."
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Iniciando transcrição..."
@@ -1233,5 +1238,6 @@ msgstr "Acrescentar e corrigir"
#~ msgid "Undo"
#~ msgstr "Desfazer"
+
#~ msgid "Redo"
#~ msgstr "Refazer"
diff --git a/buzz/locale/uk_UA/LC_MESSAGES/buzz.po b/buzz/locale/uk_UA/LC_MESSAGES/buzz.po
index f0c8d508..f45a5184 100644
--- a/buzz/locale/uk_UA/LC_MESSAGES/buzz.po
+++ b/buzz/locale/uk_UA/LC_MESSAGES/buzz.po
@@ -2,7 +2,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: \n"
"Last-Translator: Yevhen Popok \n"
"Language-Team: \n"
@@ -305,8 +305,8 @@ msgid "Download failed"
msgstr "Невдале завантаження"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "Помилка"
@@ -509,11 +509,15 @@ msgstr ""
"Будь ласка, перевірте свої аудіопристрої або пошукайте додаткову інформацію "
"в звітах програми."
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "Перевірити оновлення"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "У вас актуальна версія!"
@@ -756,12 +760,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "Не вдається додати до звʼязки ключів API-ключ OpenAI"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1133,15 +1137,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "Виникла помилка зʼєднання"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "Скасувати транскрипцію"
diff --git a/buzz/locale/zh_CN/LC_MESSAGES/buzz.po b/buzz/locale/zh_CN/LC_MESSAGES/buzz.po
index 4ea086de..0e9154a2 100644
--- a/buzz/locale/zh_CN/LC_MESSAGES/buzz.po
+++ b/buzz/locale/zh_CN/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2023-05-01 15:45+0800\n"
"Last-Translator: \n"
"Language-Team: lamb \n"
@@ -313,8 +313,8 @@ msgid "Download failed"
msgstr "下载模型失败"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr "错误"
@@ -520,11 +520,15 @@ msgid ""
"information."
msgstr "请检查您的音频设备或检查应用程序日志以获取更多信息。"
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "检查更新"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "已经是最新版本"
@@ -769,12 +773,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr "无法将OpenAI API密钥保存到密钥串"
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1147,15 +1151,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr "连接发生错误"
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "取消识别"
diff --git a/buzz/locale/zh_TW/LC_MESSAGES/buzz.po b/buzz/locale/zh_TW/LC_MESSAGES/buzz.po
index bc0f7679..ed67c2c8 100644
--- a/buzz/locale/zh_TW/LC_MESSAGES/buzz.po
+++ b/buzz/locale/zh_TW/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-10-12 19:10+0300\n"
+"POT-Creation-Date: 2025-11-28 16:49+0200\n"
"PO-Revision-Date: 2023-05-01 15:45+0800\n"
"Last-Translator: \n"
"Language-Team: Lamb\n"
@@ -308,8 +308,8 @@ msgid "Download failed"
msgstr "下載模型"
#: buzz/widgets/preferences_dialog/models_preferences_widget.py:275
-#: buzz/widgets/main_window.py:295 buzz/model_loader.py:497
-#: buzz/model_loader.py:511
+#: buzz/widgets/main_window.py:295 buzz/model_loader.py:539
+#: buzz/model_loader.py:553
msgid "Error"
msgstr ""
@@ -515,11 +515,15 @@ msgid ""
"information."
msgstr "請檢查您的音頻設備或檢查應用程序日誌以獲取更多信息。"
-#: buzz/widgets/about_dialog.py:80
+#: buzz/widgets/about_dialog.py:81
msgid "Check for updates"
msgstr "檢查更新"
-#: buzz/widgets/about_dialog.py:109
+#: buzz/widgets/about_dialog.py:84
+msgid "Show logs"
+msgstr ""
+
+#: buzz/widgets/about_dialog.py:118
msgid "You're up to date!"
msgstr "你是最新的!"
@@ -763,12 +767,12 @@ msgid "Unable to save OpenAI API key to keyring"
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:57
-#: buzz/transcriber/recording_transcriber.py:394
+#: buzz/transcriber/recording_transcriber.py:397
msgid "Whisper server failed to start. Check logs for details."
msgstr ""
#: buzz/transcriber/local_whisper_cpp_server_transcriber.py:60
-#: buzz/transcriber/recording_transcriber.py:398
+#: buzz/transcriber/recording_transcriber.py:401
msgid ""
"Whisper server failed to start due to insufficient memory. Please try again "
"with a smaller model. To force CPU mode use BUZZ_FORCE_CPU=TRUE environment "
@@ -1141,15 +1145,15 @@ msgstr ""
msgid "Cantonese"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:223 buzz/model_loader.py:530
+#: buzz/transcriber/recording_transcriber.py:224 buzz/model_loader.py:572
msgid "A connection error occurred"
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:332
+#: buzz/transcriber/recording_transcriber.py:333
msgid "Starting Whisper.cpp..."
msgstr ""
-#: buzz/transcriber/recording_transcriber.py:385
+#: buzz/transcriber/recording_transcriber.py:388
#, fuzzy
msgid "Starting transcription..."
msgstr "取消錄製"
diff --git a/buzz/transcriber/whisper_cpp.py b/buzz/transcriber/whisper_cpp.py
index 201ac450..8f12aec8 100644
--- a/buzz/transcriber/whisper_cpp.py
+++ b/buzz/transcriber/whisper_cpp.py
@@ -109,12 +109,13 @@ class WhisperCpp:
# Add translate flag if needed
if task.transcription_options.task == Task.TRANSLATE:
- cmd.append("--translate")
+ cmd.extend(["--translate"])
# Force CPU if specified
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
if force_cpu != "false" or not IS_VULKAN_SUPPORTED:
- cmd.append("--no-gpu")
+ cmd.extend(["--no-gpu"])
+ cmd.extend(["-t", str(os.getenv("BUZZ_WHISPERCPP_N_THREADS", (os.cpu_count() or 8) // 2))])
print(f"Running Whisper CLI: {' '.join(cmd)}")
@@ -125,7 +126,7 @@ class WhisperCpp:
si.wShowWindow = subprocess.SW_HIDE
process = subprocess.Popen(
cmd,
- stdout=subprocess.PIPE,
+ stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
text=True,
startupinfo=si,
@@ -135,7 +136,7 @@ class WhisperCpp:
else:
process = subprocess.Popen(
cmd,
- stdout=subprocess.PIPE,
+ stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
text=True,
)
diff --git a/buzz/widgets/about_dialog.py b/buzz/widgets/about_dialog.py
index d4b3abbb..5c6b6757 100644
--- a/buzz/widgets/about_dialog.py
+++ b/buzz/widgets/about_dialog.py
@@ -1,5 +1,6 @@
import json
from typing import Optional
+from platformdirs import user_log_dir
from PyQt6 import QtGui
from PyQt6.QtCore import Qt, QUrl
@@ -80,6 +81,9 @@ class AboutDialog(QDialog):
self.check_updates_button = QPushButton(_("Check for updates"), self)
self.check_updates_button.clicked.connect(self.on_click_check_for_updates)
+ self.show_logs_button = QPushButton(_("Show logs"), self)
+ self.show_logs_button.clicked.connect(self.on_click_show_logs)
+
button_box = QDialogButtonBox(
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Close), self
)
@@ -90,15 +94,21 @@ class AboutDialog(QDialog):
layout.addWidget(buzz_label)
layout.addWidget(version_label)
layout.addWidget(self.check_updates_button)
+ layout.addWidget(self.show_logs_button)
layout.addWidget(button_box)
self.setLayout(layout)
+ self.setMinimumWidth(350)
def on_click_check_for_updates(self):
url = QUrl(self.GITHUB_API_LATEST_RELEASE_URL)
self.network_access_manager.get(QNetworkRequest(url))
self.check_updates_button.setDisabled(True)
+ def on_click_show_logs(self):
+ log_dir = user_log_dir(appname="Buzz")
+ QDesktopServices.openUrl(QUrl.fromLocalFile(log_dir))
+
def on_latest_release_reply(self, reply: QNetworkReply):
if reply.error() == QNetworkReply.NetworkError.NoError:
response = json.loads(reply.readAll().data())
diff --git a/docs/docs/faq.md b/docs/docs/faq.md
index 4de7f377..ab47a824 100644
--- a/docs/docs/faq.md
+++ b/docs/docs/faq.md
@@ -13,6 +13,8 @@ The models are stored:
Paste the location in your file manager to access the models.
+Since version `1.3.4`, to get to the logs folder go to `Help -> About Buzz` and click the `Show logs` button.
+
### 2. What can I try if the transcription runs too slowly?
Speech recognition requires large amount of computation, so one option is to try using a lower Whisper model size or using a Whisper.cpp model to run speech recognition of your computer. If you have access to a computer with GPU that has at least 6GB of VRAM you can try using the Faster Whisper model.
diff --git a/pyproject.toml b/pyproject.toml
index 01894149..094ffccd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "buzz-captions"
-version = "1.3.3"
+version = "1.3.4"
description = ""
authors = [{ name = "Chidi Williams", email = "williamschidi1@gmail.com" }]
requires-python = ">=3.12,<3.13"
diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml
index 346ff7e4..53cc91db 100644
--- a/snap/snapcraft.yaml
+++ b/snap/snapcraft.yaml
@@ -153,7 +153,7 @@ apps:
desktop: usr/share/applications/buzz.desktop
environment:
PATH: $SNAP/usr/bin:$SNAP/bin:$PATH
- LD_LIBRARY_PATH: $SNAP/lib/python3.12/site-packages/nvidia/cudnn/lib:$SNAP/lib/python3.12/site-packages/PyQt6:$SNAP/lib/python3.12/site-packages/PyQt6/Qt6/lib:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/lapack:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/blas:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/libproxy:$SNAP:$LD_LIBRARY_PATH
+ LD_LIBRARY_PATH: $SNAP/lib/python3.12/site-packages/nvidia/cudnn/lib:$SNAP/lib/python3.12/site-packages/PyQt6:$SNAP/lib/python3.12/site-packages/PyQt6/Qt6/lib:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/lapack:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/blas:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/oss4-libsalsa:$SNAP/usr/lib/$CRAFT_ARCH_TRIPLET_BUILD_FOR/libproxy:$SNAP:$LD_LIBRARY_PATH
PYTHONPATH: $SNAP:$SNAP/lib/python3.12/site-packages/PyQt6:$SNAP/lib/python3.12/site-packages/PyQt6/Qt6/lib:$SNAP/usr/lib/python3/dist-packages:$SNAP/usr/lib/python3.12/site-packages:$SNAP/usr/local/lib/python3.12/dist-packages:$SNAP/usr/lib/python3.12/dist-packages:$PYTHONPATH
QT_MEDIA_BACKEND: gstreamer
PULSE_LATENCY_MSEC: "30"
diff --git a/whisper.cpp b/whisper.cpp
index a8d002cf..4979e04f 160000
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1 +1 @@
-Subproject commit a8d002cfd879315632a579e73f0148d06959de36
+Subproject commit 4979e04f5dcaccb36057e059bbaed8a2f5288315
From 73376a63ac660b94f9f94175910f224ff0ed0d2b Mon Sep 17 00:00:00 2001
From: Raivis Dejus
Date: Tue, 2 Dec 2025 21:39:24 +0200
Subject: [PATCH 12/73] Add speaker identification2 (#1290)
Co-authored-by: David Olowomeye <100958002+greatdaveo@users.noreply.github.com>
---
.coveragerc | 5 +-
.github/workflows/ci.yml | 4 +-
.github/workflows/snapcraft.yml | 13 +
.gitignore | 3 +-
.gitmodules | 12 +
Buzz.spec | 8 +-
CONTRIBUTING.md | 6 +-
Makefile | 7 +-
README.md | 2 +-
buzz/__version__.py | 2 +-
buzz/assets/speaker-identification.svg | 14 +
buzz/buzz.py | 11 +
buzz/file_transcriber_queue_worker.py | 2 +-
buzz/locale/ca_ES/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/da_DK/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/de_DE/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/en_US/LC_MESSAGES/buzz.po | 143 +-
buzz/locale/es_ES/LC_MESSAGES/buzz.po | 147 +-
buzz/locale/it_IT/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/ja_JP/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/lv_LV/LC_MESSAGES/buzz.po | 146 +-
buzz/locale/nl/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/pl_PL/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/pt_BR/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/uk_UA/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/zh_CN/LC_MESSAGES/buzz.po | 145 +-
buzz/locale/zh_TW/LC_MESSAGES/buzz.po | 145 +-
buzz/transcriber/whisper_cpp.py | 5 +-
buzz/widgets/icon.py | 4 +
.../speaker_identification_widget.py | 504 ++++
.../transcription_resizer_widget.py | 10 +-
.../transcription_viewer_widget.py | 238 +-
ctc_forced_aligner | 1 +
deepmultilingualpunctuation | 1 +
demucs/.github/ISSUE_TEMPLATE/bug.md | 33 +
demucs/.github/ISSUE_TEMPLATE/question.md | 10 +
demucs/.github/workflows/linter.yml | 36 +
demucs/.github/workflows/tests.yml | 36 +
demucs/.gitignore | 17 +
demucs/CODE_OF_CONDUCT.md | 76 +
demucs/CONTRIBUTING.md | 23 +
demucs/Demucs.ipynb | 153 ++
demucs/LICENSE | 21 +
demucs/MANIFEST.in | 13 +
demucs/Makefile | 36 +
demucs/README.md | 319 +++
demucs/Readme.md | 1 -
demucs/conf/config.yaml | 304 +++
demucs/conf/dset/aetl.yaml | 19 +
demucs/conf/dset/auto_extra_test.yaml | 18 +
demucs/conf/dset/auto_mus.yaml | 20 +
demucs/conf/dset/extra44.yaml | 8 +
demucs/conf/dset/extra_mmi_goodclean.yaml | 12 +
demucs/conf/dset/extra_test.yaml | 12 +
demucs/conf/dset/musdb44.yaml | 5 +
demucs/conf/dset/sdx23_bleeding.yaml | 10 +
demucs/conf/dset/sdx23_labelnoise.yaml | 10 +
demucs/conf/svd/base.yaml | 14 +
demucs/conf/svd/base2.yaml | 14 +
demucs/conf/svd/default.yaml | 1 +
demucs/conf/variant/default.yaml | 1 +
demucs/conf/variant/example.yaml | 5 +
demucs/conf/variant/finetune.yaml | 19 +
demucs/demucs.png | Bin 0 -> 339294 bytes
demucs/{ => demucs}/__init__.py | 0
demucs/{ => demucs}/__main__.py | 0
demucs/{ => demucs}/api.py | 0
demucs/{ => demucs}/apply.py | 0
demucs/{ => demucs}/audio.py | 0
demucs/{ => demucs}/audio_legacy.py | 0
demucs/{ => demucs}/augment.py | 0
demucs/{ => demucs}/demucs.py | 0
demucs/{ => demucs}/distrib.py | 0
demucs/{ => demucs}/ema.py | 0
demucs/{ => demucs}/evaluate.py | 0
demucs/{ => demucs}/grids/__init__.py | 0
demucs/{ => demucs}/grids/_explorers.py | 0
demucs/{ => demucs}/grids/mdx.py | 0
demucs/{ => demucs}/grids/mdx_extra.py | 0
demucs/{ => demucs}/grids/mdx_refine.py | 0
demucs/{ => demucs}/grids/mmi.py | 0
demucs/{ => demucs}/grids/mmi_ft.py | 0
demucs/{ => demucs}/grids/repro.py | 0
demucs/{ => demucs}/grids/repro_ft.py | 0
demucs/{ => demucs}/grids/sdx23.py | 0
demucs/{ => demucs}/hdemucs.py | 0
demucs/{ => demucs}/htdemucs.py | 0
demucs/{ => demucs}/pretrained.py | 0
demucs/{ => demucs}/py.typed | 0
demucs/{ => demucs}/remote/files.txt | 0
demucs/{ => demucs}/remote/hdemucs_mmi.yaml | 0
demucs/{ => demucs}/remote/htdemucs.yaml | 0
demucs/{ => demucs}/remote/htdemucs_6s.yaml | 0
demucs/{ => demucs}/remote/htdemucs_ft.yaml | 0
demucs/{ => demucs}/remote/mdx.yaml | 0
demucs/{ => demucs}/remote/mdx_extra.yaml | 0
demucs/{ => demucs}/remote/mdx_extra_q.yaml | 0
demucs/{ => demucs}/remote/mdx_q.yaml | 0
demucs/{ => demucs}/remote/repro_mdx_a.yaml | 0
.../remote/repro_mdx_a_hybrid_only.yaml | 0
.../remote/repro_mdx_a_time_only.yaml | 0
demucs/{ => demucs}/repitch.py | 0
demucs/{ => demucs}/repo.py | 0
demucs/{ => demucs}/separate.py | 0
demucs/{ => demucs}/solver.py | 0
demucs/{ => demucs}/spec.py | 0
demucs/{ => demucs}/states.py | 0
demucs/{ => demucs}/svd.py | 0
demucs/{ => demucs}/train.py | 0
demucs/{ => demucs}/transformer.py | 0
demucs/{ => demucs}/utils.py | 0
demucs/{ => demucs}/wav.py | 0
demucs/{ => demucs}/wdemucs.py | 0
demucs/docs/api.md | 204 ++
demucs/docs/linux.md | 28 +
demucs/docs/mac.md | 28 +
demucs/docs/mdx.md | 73 +
demucs/docs/release.md | 114 +
demucs/docs/sdx23.md | 61 +
demucs/docs/training.md | 290 +++
demucs/docs/windows.md | 67 +
demucs/environment-cpu.yml | 28 +
demucs/environment-cuda.yml | 28 +
demucs/hubconf.py | 11 +
demucs/mypy.ini | 5 +
demucs/outputs.tar.gz | Bin 0 -> 1885 bytes
demucs/requirements.txt | 19 +
demucs/requirements_minimal.txt | 10 +
demucs/setup.cfg | 8 +
demucs/setup.py | 75 +
demucs/test.mp3 | Bin 0 -> 802480 bytes
demucs/tools/__init__.py | 5 +
demucs/tools/automix.py | 343 +++
demucs/tools/bench.py | 78 +
demucs/tools/convert.py | 152 ++
demucs/tools/export.py | 71 +
demucs/tools/test_pretrained.py | 43 +
docs/docs/preferences.md | 2 +-
docs/docs/usage/1_file_import.md | 2 +
docs/docs/usage/5_speaker_identification.md | 9 +
flatpak/run-buzz.sh | 2 +
hatch_build.py | 37 +
pyproject.toml | 16 +-
pytest.ini | 1 +
share/applications/buzz.desktop | 17 +
snap/snapcraft.yaml | 12 +-
tests/gui_test.py | 5 +-
.../file_transcriber_queue_worker_test.py | 6 +-
tests/transcriber/whisper_cpp_test.py | 2 +-
.../speaker_identification_widget_test.py | 90 +
tests/widgets/transcription_viewer_test.py | 8 +-
uv.lock | 2024 ++++++++++++++++-
whisper_diarization | 1 +
153 files changed, 7397 insertions(+), 707 deletions(-)
create mode 100644 buzz/assets/speaker-identification.svg
create mode 100644 buzz/widgets/transcription_viewer/speaker_identification_widget.py
create mode 160000 ctc_forced_aligner
create mode 160000 deepmultilingualpunctuation
create mode 100644 demucs/.github/ISSUE_TEMPLATE/bug.md
create mode 100644 demucs/.github/ISSUE_TEMPLATE/question.md
create mode 100644 demucs/.github/workflows/linter.yml
create mode 100644 demucs/.github/workflows/tests.yml
create mode 100644 demucs/.gitignore
create mode 100644 demucs/CODE_OF_CONDUCT.md
create mode 100644 demucs/CONTRIBUTING.md
create mode 100644 demucs/Demucs.ipynb
create mode 100644 demucs/LICENSE
create mode 100644 demucs/MANIFEST.in
create mode 100644 demucs/Makefile
create mode 100644 demucs/README.md
delete mode 100644 demucs/Readme.md
create mode 100644 demucs/conf/config.yaml
create mode 100644 demucs/conf/dset/aetl.yaml
create mode 100644 demucs/conf/dset/auto_extra_test.yaml
create mode 100644 demucs/conf/dset/auto_mus.yaml
create mode 100644 demucs/conf/dset/extra44.yaml
create mode 100644 demucs/conf/dset/extra_mmi_goodclean.yaml
create mode 100644 demucs/conf/dset/extra_test.yaml
create mode 100644 demucs/conf/dset/musdb44.yaml
create mode 100644 demucs/conf/dset/sdx23_bleeding.yaml
create mode 100644 demucs/conf/dset/sdx23_labelnoise.yaml
create mode 100644 demucs/conf/svd/base.yaml
create mode 100644 demucs/conf/svd/base2.yaml
create mode 100644 demucs/conf/svd/default.yaml
create mode 100644 demucs/conf/variant/default.yaml
create mode 100644 demucs/conf/variant/example.yaml
create mode 100644 demucs/conf/variant/finetune.yaml
create mode 100644 demucs/demucs.png
rename demucs/{ => demucs}/__init__.py (100%)
rename demucs/{ => demucs}/__main__.py (100%)
rename demucs/{ => demucs}/api.py (100%)
rename demucs/{ => demucs}/apply.py (100%)
rename demucs/{ => demucs}/audio.py (100%)
rename demucs/{ => demucs}/audio_legacy.py (100%)
rename demucs/{ => demucs}/augment.py (100%)
rename demucs/{ => demucs}/demucs.py (100%)
rename demucs/{ => demucs}/distrib.py (100%)
rename demucs/{ => demucs}/ema.py (100%)
rename demucs/{ => demucs}/evaluate.py (100%)
rename demucs/{ => demucs}/grids/__init__.py (100%)
rename demucs/{ => demucs}/grids/_explorers.py (100%)
rename demucs/{ => demucs}/grids/mdx.py (100%)
rename demucs/{ => demucs}/grids/mdx_extra.py (100%)
rename demucs/{ => demucs}/grids/mdx_refine.py (100%)
rename demucs/{ => demucs}/grids/mmi.py (100%)
rename demucs/{ => demucs}/grids/mmi_ft.py (100%)
rename demucs/{ => demucs}/grids/repro.py (100%)
rename demucs/{ => demucs}/grids/repro_ft.py (100%)
rename demucs/{ => demucs}/grids/sdx23.py (100%)
rename demucs/{ => demucs}/hdemucs.py (100%)
rename demucs/{ => demucs}/htdemucs.py (100%)
rename demucs/{ => demucs}/pretrained.py (100%)
rename demucs/{ => demucs}/py.typed (100%)
rename demucs/{ => demucs}/remote/files.txt (100%)
rename demucs/{ => demucs}/remote/hdemucs_mmi.yaml (100%)
rename demucs/{ => demucs}/remote/htdemucs.yaml (100%)
rename demucs/{ => demucs}/remote/htdemucs_6s.yaml (100%)
rename demucs/{ => demucs}/remote/htdemucs_ft.yaml (100%)
rename demucs/{ => demucs}/remote/mdx.yaml (100%)
rename demucs/{ => demucs}/remote/mdx_extra.yaml (100%)
rename demucs/{ => demucs}/remote/mdx_extra_q.yaml (100%)
rename demucs/{ => demucs}/remote/mdx_q.yaml (100%)
rename demucs/{ => demucs}/remote/repro_mdx_a.yaml (100%)
rename demucs/{ => demucs}/remote/repro_mdx_a_hybrid_only.yaml (100%)
rename demucs/{ => demucs}/remote/repro_mdx_a_time_only.yaml (100%)
rename demucs/{ => demucs}/repitch.py (100%)
rename demucs/{ => demucs}/repo.py (100%)
rename demucs/{ => demucs}/separate.py (100%)
rename demucs/{ => demucs}/solver.py (100%)
rename demucs/{ => demucs}/spec.py (100%)
rename demucs/{ => demucs}/states.py (100%)
rename demucs/{ => demucs}/svd.py (100%)
rename demucs/{ => demucs}/train.py (100%)
rename demucs/{ => demucs}/transformer.py (100%)
rename demucs/{ => demucs}/utils.py (100%)
rename demucs/{ => demucs}/wav.py (100%)
rename demucs/{ => demucs}/wdemucs.py (100%)
create mode 100644 demucs/docs/api.md
create mode 100644 demucs/docs/linux.md
create mode 100644 demucs/docs/mac.md
create mode 100644 demucs/docs/mdx.md
create mode 100644 demucs/docs/release.md
create mode 100644 demucs/docs/sdx23.md
create mode 100644 demucs/docs/training.md
create mode 100644 demucs/docs/windows.md
create mode 100644 demucs/environment-cpu.yml
create mode 100644 demucs/environment-cuda.yml
create mode 100644 demucs/hubconf.py
create mode 100644 demucs/mypy.ini
create mode 100644 demucs/outputs.tar.gz
create mode 100644 demucs/requirements.txt
create mode 100644 demucs/requirements_minimal.txt
create mode 100644 demucs/setup.cfg
create mode 100644 demucs/setup.py
create mode 100644 demucs/test.mp3
create mode 100644 demucs/tools/__init__.py
create mode 100644 demucs/tools/automix.py
create mode 100644 demucs/tools/bench.py
create mode 100644 demucs/tools/convert.py
create mode 100644 demucs/tools/export.py
create mode 100644 demucs/tools/test_pretrained.py
create mode 100644 docs/docs/usage/5_speaker_identification.md
create mode 100644 share/applications/buzz.desktop
create mode 100644 tests/widgets/speaker_identification_widget_test.py
create mode 160000 whisper_diarization
diff --git a/.coveragerc b/.coveragerc
index c8f35eab..566ba584 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,9 +1,12 @@
[run]
omit =
buzz/whisper_cpp/*
+ buzz/transcriber/local_whisper_cpp_server_transcriber.py
*_test.py
demucs/*
- buzz/transcriber/local_whisper_cpp_server_transcriber.py
+ whisper_diarization/*
+ deepmultilingualpunctuation/*
+ ctc_forced_aligner/*
[html]
directory = coverage/html
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 010e183a..54e7158d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -70,10 +70,10 @@ jobs:
~/AppData/Local/Buzz/Buzz/Cache
key: whisper-models
- - uses: AnimMouse/setup-ffmpeg@v1.2.1
+ - uses: AnimMouse/setup-ffmpeg@v1
id: setup-ffmpeg
with:
- version: ${{ matrix.os == 'macos-15-intel' && '7.1.1' || matrix.os == 'macos-latest' && '71' || '7.1' }}
+ version: ${{ matrix.os == 'macos-15-intel' && '7.1.1' || matrix.os == 'macos-latest' && '80' || '8.0' }}
- name: Test ffmpeg
run: ffmpeg -i ./testdata/audio-long.mp3 ./testdata/audio-long.wav
diff --git a/.github/workflows/snapcraft.yml b/.github/workflows/snapcraft.yml
index 0b1ecec3..286fe59c 100644
--- a/.github/workflows/snapcraft.yml
+++ b/.github/workflows/snapcraft.yml
@@ -15,9 +15,22 @@ concurrency:
jobs:
build:
runs-on: ubuntu-latest
+ timeout-minutes: 90
+ env:
+ BUZZ_DISABLE_TELEMETRY: true
outputs:
snap: ${{ steps.snapcraft.outputs.snap }}
steps:
+ # Ideas from https://github.com/orgs/community/discussions/25678
+ - name: Remove unused build tools
+ run: |
+ sudo apt-get remove -y '^llvm-.*'
+ sudo apt-get remove -y 'php.*'
+ sudo apt-get remove -y azure-cli google-cloud-sdk hhvm google-chrome-stable firefox powershell mono-devel || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ python -m pip cache purge
+ rm -rf /opt/hostedtoolcache || true
- name: Maximize build space
uses: easimon/maximize-build-space@master
with:
diff --git a/.gitignore b/.gitignore
index f0c01776..66f3b3ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,4 +31,5 @@ benchmarks.json
/coverage/
/wheelhouse/
/.flatpak-builder
-/repo
\ No newline at end of file
+/repo
+/nemo_msdd_configs
diff --git a/.gitmodules b/.gitmodules
index fa83e220..1c0c8b24 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,15 @@
[submodule "whisper.cpp"]
path = whisper.cpp
url = https://github.com/ggerganov/whisper.cpp
+[submodule "whisper_diarization"]
+ path = whisper_diarization
+ url = https://github.com/MahmoudAshraf97/whisper-diarization
+[submodule "demucs"]
+ path = demucs
+ url = https://github.com/MahmoudAshraf97/demucs.git
+[submodule "deepmultilingualpunctuation"]
+ path = deepmultilingualpunctuation
+ url = https://github.com/oliverguhr/deepmultilingualpunctuation.git
+[submodule "ctc_forced_aligner"]
+ path = ctc_forced_aligner
+ url = https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
diff --git a/Buzz.spec b/Buzz.spec
index 0f4e8edb..c2d93bb1 100644
--- a/Buzz.spec
+++ b/Buzz.spec
@@ -30,7 +30,13 @@ datas += collect_data_files("transformers", include_py_files=True)
datas += collect_data_files("faster_whisper", include_py_files=True)
datas += collect_data_files("stable_whisper", include_py_files=True)
datas += collect_data_files("whisper")
-datas += [("demucs", "demucs")]
+datas += collect_data_files("demucs", include_py_files=True)
+datas += collect_data_files("whisper_diarization", include_py_files=True)
+datas += collect_data_files("deepmultilingualpunctuation", include_py_files=True)
+datas += collect_data_files("ctc_forced_aligner", include_py_files=True)
+datas += collect_data_files("nemo", include_py_files=True)
+datas += collect_data_files("lightning_fabric", include_py_files=True)
+datas += collect_data_files("pytorch_lightning", include_py_files=True)
datas += [("buzz/assets/*", "assets")]
datas += [("buzz/locale", "locale")]
datas += [("buzz/schema.sql", ".")]
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index d8a540cf..3b2fddf4 100755
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -53,8 +53,7 @@ sudo apt-get install --no-install-recommends libyaml-dev libtbb-dev libxkbcommon
```
On versions prior to Ubuntu 24.04 install `sudo apt-get install --no-install-recommends libegl1-mesa`
5. Install the dependencies `uv sync`
-6. Build Buzz `uv build`
-7. Run Buzz `uv run buzz`
+6. Run Buzz `uv run buzz`
#### Necessary dependencies for Faster Whisper on GPU
@@ -81,8 +80,7 @@ On versions prior to Ubuntu 24.04 install `sudo apt-get install --no-install-rec
3. Install uv `curl -LsSf https://astral.sh/uv/install.sh | sh` (or `brew install uv`)
4. Install system dependencies you may be missing `brew install ffmpeg`
5. Install the dependencies `uv sync`
-6. Build Buzz `uv build`
-7. Run Buzz `uv run buzz`
+6. Run Buzz `uv run buzz`
diff --git a/Makefile b/Makefile
index af2aa9a1..92315dbd 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,5 @@
-version := 1.3.4
+# Change also in pyproject.toml and buzz/__version__.py
+version := 1.4.0
mac_app_path := ./dist/Buzz.app
mac_zip_path := ./dist/Buzz-${version}-mac.zip
@@ -28,7 +29,7 @@ else
rm -rf dist/* || true
endif
-COVERAGE_THRESHOLD := 75
+COVERAGE_THRESHOLD := 70
test: buzz/whisper_cpp
pytest -s -vv --cov=buzz --cov-report=xml --cov-report=html --benchmark-skip --cov-fail-under=${COVERAGE_THRESHOLD} --cov-config=.coveragerc
@@ -67,7 +68,7 @@ ifeq ($(shell uname -s), Linux)
cp whisper.cpp/build/bin/whisper-server buzz/whisper_cpp/ || true
cp whisper.cpp/build/src/libwhisper.so buzz/whisper_cpp/ || true
cp whisper.cpp/build/src/libwhisper.so.1 buzz/whisper_cpp/ || true
- cp whisper.cpp/build/src/libwhisper.so.1.7.6 buzz/whisper_cpp/ || true
+ cp whisper.cpp/build/src/libwhisper.so.1.8.2 buzz/whisper_cpp/ || true
cp whisper.cpp/build/ggml/src/libggml.so buzz/whisper_cpp/ || true
cp whisper.cpp/build/ggml/src/libggml-base.so buzz/whisper_cpp/ || true
cp whisper.cpp/build/ggml/src/libggml-cpu.so buzz/whisper_cpp/ || true
diff --git a/README.md b/README.md
index 55c62f9d..7b5db725 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ OpenAI's [Whisper](https://github.com/openai/whisper).
[](https://GitHub.com/chidiwilliams/buzz/releases/)
-
Buzz is better on the App Store. Get a Mac-native version of Buzz with a cleaner look, audio playback, drag-and-drop import, transcript editing, search, and much more.
+
An older version of Buzz available on the App Store. Get a Mac-native version of Buzz with a cleaner look, audio playback, drag-and-drop import, transcript editing, search, and much more.
diff --git a/buzz/__version__.py b/buzz/__version__.py
index 4a16f216..af63e4ae 100644
--- a/buzz/__version__.py
+++ b/buzz/__version__.py
@@ -1 +1 @@
-VERSION = "1.3.4"
+VERSION = "1.4.0"
diff --git a/buzz/assets/speaker-identification.svg b/buzz/assets/speaker-identification.svg
new file mode 100644
index 00000000..cfea8b41
--- /dev/null
+++ b/buzz/assets/speaker-identification.svg
@@ -0,0 +1,14 @@
+
+
\ No newline at end of file
diff --git a/buzz/buzz.py b/buzz/buzz.py
index e6f755f7..6c4750d6 100644
--- a/buzz/buzz.py
+++ b/buzz/buzz.py
@@ -56,6 +56,17 @@ def main():
format=log_format,
)
+ # Silence noisy third-party library loggers
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+ logging.getLogger("graphviz").setLevel(logging.WARNING)
+ logging.getLogger("nemo_logger").setLevel(logging.ERROR)
+ logging.getLogger("numba").setLevel(logging.WARNING)
+ logging.getLogger("torio._extension.utils").setLevel(logging.WARNING)
+ logging.getLogger("export_config_manager").setLevel(logging.WARNING)
+ logging.getLogger("training_telemetry_provider").setLevel(logging.ERROR)
+ logging.getLogger("default_recorder").setLevel(logging.WARNING)
+ logging.getLogger("config").setLevel(logging.WARNING)
+
if getattr(sys, "frozen", False) is False:
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
diff --git a/buzz/file_transcriber_queue_worker.py b/buzz/file_transcriber_queue_worker.py
index f6cf91fb..b056981f 100644
--- a/buzz/file_transcriber_queue_worker.py
+++ b/buzz/file_transcriber_queue_worker.py
@@ -7,7 +7,7 @@ from uuid import UUID
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
-from demucs import api as demucsApi
+from demucs.demucs import api as demucsApi
from buzz.model_loader import ModelType
from buzz.transcriber.file_transcriber import FileTranscriber
diff --git a/buzz/locale/ca_ES/LC_MESSAGES/buzz.po b/buzz/locale/ca_ES/LC_MESSAGES/buzz.po
index e88359f6..aaf56614 100644
--- a/buzz/locale/ca_ES/LC_MESSAGES/buzz.po
+++ b/buzz/locale/ca_ES/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: buzz\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2025-10-17 07:59+0200\n"
"Last-Translator: Éric Duarte \n"
"Language-Team: Catalan \n"
@@ -554,64 +554,68 @@ msgstr "Veure"
msgid "Timestamps"
msgstr "Marqua de temps"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Exporta"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Traduir"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Redimensionar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr "Cerca"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr "Mostra/amaga la barra de cerca (Ctrl+F)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr "Cerca:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr "Introduïu el text a cercar..."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr "Coincidència anterior (Maj+Retorn)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr "Coincidència següent (retorn)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr "Neteja"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr "Controls de reproducció:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr "Segment de bucle"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr "Activa/desactiva el bucle en fer clic als segments de transcripció"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr "Segueix l'àudio"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
@@ -619,75 +623,146 @@ msgstr ""
"Activa/desactiva seguint la posició d'àudio actual a la transcripció. Quan "
"està activada, es desplaça automàticament al text actual."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr "Desplaça't fins a l'actual"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr "Desplaçar-se fins al text que es parla actualment"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr "1 de més de 100 coincidències"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr "1 de "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr " coincidències"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr "No s'ha trobat cap coincidència"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr " de més de 100 coincidències"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr " de "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "Clau API necessària"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Introduïu la clau API d'OpenAI a les preferències"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Opcions de redimensionament"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Longitud desitjada dels subtítols"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Opcions de fusió"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Fusiona per buit"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Divideix per puntuació"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Divideix per la longitud màxima"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Fusiona"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Cancel·la la transcripció"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Desa el fitxer"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Desa el fitxer"
diff --git a/buzz/locale/da_DK/LC_MESSAGES/buzz.po b/buzz/locale/da_DK/LC_MESSAGES/buzz.po
index 7d356c67..fe698374 100644
--- a/buzz/locale/da_DK/LC_MESSAGES/buzz.po
+++ b/buzz/locale/da_DK/LC_MESSAGES/buzz.po
@@ -2,7 +2,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: \n"
"Last-Translator: Ole Guldberg2 \n"
"Language-Team: \n"
@@ -552,138 +552,213 @@ msgstr "Vis"
msgid "Timestamps"
msgstr "Tidsstempler"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Eksporter"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Oversæt"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Behandel størrelse"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "API-nøgle påkrævet"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Indtast venligst OpenAI API-nøgle i indstillinger"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Størrelsesindstillinger"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Ønskede undertekst længde"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Sammenfletningsindstillinger"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Sammenflet ved hul"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Split ved punktum"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Split ved max længde"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Sammenflet"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Afbryd transkription"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Gem fil"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Gem fil"
diff --git a/buzz/locale/de_DE/LC_MESSAGES/buzz.po b/buzz/locale/de_DE/LC_MESSAGES/buzz.po
index 14ebc504..1b547455 100644
--- a/buzz/locale/de_DE/LC_MESSAGES/buzz.po
+++ b/buzz/locale/de_DE/LC_MESSAGES/buzz.po
@@ -6,7 +6,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2025-03-05 14:41+0100\n"
"Last-Translator: \n"
"Language-Team: \n"
@@ -552,138 +552,213 @@ msgstr "Anzeigen"
msgid "Timestamps"
msgstr "Zeitstempel"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Export"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Übersetzen"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Größe ändern"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "API-Schlüssel erforderlich"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Bitte geben Sie den OpenAI-API-Schlüssel in den Einstellungen ein"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Größenänderungsoptionen"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Gewünschte Untertitellänge"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Zusammenführungsoptionen"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Nach Abstand zusammenführen"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Durch Satzzeichen getrennt"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Aufgeteilt nach maximaler Länge"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Vereinigen"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Transkription abbrechen"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Datei speichern"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Datei speichern"
diff --git a/buzz/locale/en_US/LC_MESSAGES/buzz.po b/buzz/locale/en_US/LC_MESSAGES/buzz.po
index 87f47cea..02bac9f4 100644
--- a/buzz/locale/en_US/LC_MESSAGES/buzz.po
+++ b/buzz/locale/en_US/LC_MESSAGES/buzz.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language-Team: LANGUAGE \n"
@@ -540,138 +540,211 @@ msgstr ""
msgid "Timestamps"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr ""
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+msgid "5/8 Preparing transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+msgid "Save"
+msgstr ""
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr ""
diff --git a/buzz/locale/es_ES/LC_MESSAGES/buzz.po b/buzz/locale/es_ES/LC_MESSAGES/buzz.po
index 133209e1..1c7d3e0c 100644
--- a/buzz/locale/es_ES/LC_MESSAGES/buzz.po
+++ b/buzz/locale/es_ES/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2025-09-08 12:43+0200\n"
"Last-Translator: Éric Duarte \n"
"Language-Team: \n"
@@ -589,66 +589,70 @@ msgstr "Ver"
msgid "Timestamps"
msgstr "Marcas de tiempo"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Exportar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Traducir"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Cambiar el tamaño"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr "Buscar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr "Mostrar/Ocultar barra de búsqueda (Ctrl+F)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr "Encontrar:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr "Introducir texto para encontrar..."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr "Coincidencia anterior (Mayús+Intro)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr "Siguiente coincidencia (Enter)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr "Limpiar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr "Controles de reproducción:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr "Segmento de bucle"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
"Activar/desactivar la reproducción en bucle al hacer clic en segmentos de la "
"transcripción"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr "Seguir audio"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
@@ -657,75 +661,148 @@ msgstr ""
"transcripción. Cuando está activado, se desplaza automáticamente al texto "
"actual."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr "Desplácese hasta Actual"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr "Desplazarse hasta el texto hablado actualmente"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr "1 de 100+ coincidencias"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr "1 de "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr " coincidencias"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr "No se encontraron coincidencias"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr " de 100+ coincidencias"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr " de "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "Clave de API requerida"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Ingrese la clave API de OpenAI en las preferencias"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Opciones de cambio de tamaño"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Longitud deseada de los subtítulos"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Opciones de fusión"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Fusión por hueco"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Dividido por puntuación"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Dividido por la longitud máxima"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Fusión"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+# automatic translation
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Cancelar transcripción"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+# automatic translation
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Guardar archivo"
+
# automatic translation
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
diff --git a/buzz/locale/it_IT/LC_MESSAGES/buzz.po b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
index 127a3b0c..fd231e6a 100644
--- a/buzz/locale/it_IT/LC_MESSAGES/buzz.po
+++ b/buzz/locale/it_IT/LC_MESSAGES/buzz.po
@@ -6,7 +6,7 @@ msgid ""
msgstr ""
"Project-Id-Version: buzz\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2025-11-09 20:22+0200\n"
"Language-Team: (Italiano) Albano Battistella \n"
"Language: it_IT\n"
@@ -555,65 +555,69 @@ msgstr "Visualizza"
msgid "Timestamps"
msgstr "Timestamp"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Esporta"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Tradurre"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Ridimensionare"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr "Trova"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr "Mostra/Nascondi barra di ricerca (Ctrl+F)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr "Trova:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr "Inserisci il testo per trovare..."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr "Corrispondenza precedente (Maiusc+Invio)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr "Prossima corrispondenza (Invio)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr "Elimina"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr "Controlli di riproduzione:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr "Ciclo di segmento"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
"Abilita/disabilita il loop quando si fa clic sui segmenti della trascrizione"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr "Segui Audio"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
@@ -622,75 +626,146 @@ msgstr ""
"trascrizione. Quando abilitato, scorre automaticamente fino al testo "
"corrente."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr "Scorri fino al Corrente"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr "Scorrere fino al testo attualmente pronunciato"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr "1 di 100+ corrispondenze"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr "1 di "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr " corrispondenze"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr "Nessuna corrispondenza trovata"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr " di oltre 100 corrispondenze"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr " di "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "Chiave API richiesta"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Inserisci la chiave API OpenAI nelle preferenze"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Opzioni di ridimensionamento"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Lunghezza desiderata dei sottotitoli"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Opzioni di unione"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Unito per spazio"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Diviso per punteggiatura"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Diviso per lunghezza massima"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Unione"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Inizio trascrizione..."
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Salva file"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Salva file"
diff --git a/buzz/locale/ja_JP/LC_MESSAGES/buzz.po b/buzz/locale/ja_JP/LC_MESSAGES/buzz.po
index b5ec2b11..2bdda8b2 100644
--- a/buzz/locale/ja_JP/LC_MESSAGES/buzz.po
+++ b/buzz/locale/ja_JP/LC_MESSAGES/buzz.po
@@ -2,7 +2,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: \n"
"Last-Translator: nunawa <71294849+nunawa@users.noreply.github.com>\n"
"Language-Team: \n"
@@ -548,139 +548,214 @@ msgstr "表示"
msgid "Timestamps"
msgstr "タイムスタンプ"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "出力"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "翻訳"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "リサイズ"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "APIキーが必要"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "設定画面でOpenAI APIキーを入力してください"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
#, fuzzy
msgid "Resize Options"
msgstr "リサイズ"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "希望する字幕の長さ"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr ""
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "文字起こしをキャンセルする"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "ファイルを保存"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "ファイルを保存"
diff --git a/buzz/locale/lv_LV/LC_MESSAGES/buzz.po b/buzz/locale/lv_LV/LC_MESSAGES/buzz.po
index 18f799f5..df528784 100644
--- a/buzz/locale/lv_LV/LC_MESSAGES/buzz.po
+++ b/buzz/locale/lv_LV/LC_MESSAGES/buzz.po
@@ -3,13 +3,12 @@
# This file is distributed under the same license as the PACKAGE package.
# FIRST AUTHOR , YEAR.
#
-#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
-"PO-Revision-Date: 2025-11-28 16:50+0200\n"
+"POT-Creation-Date: 2025-11-23 13:02+0200\n"
+"PO-Revision-Date: 2025-11-23 12:58+0200\n"
"Last-Translator: \n"
"Language-Team: \n"
"Language: lv_LV\n"
@@ -558,64 +557,68 @@ msgstr "Skats"
msgid "Timestamps"
msgstr "Laiks"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Eksportēt"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Tulkot"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Mainīt garumu"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr "Noteikt runātājus"
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr "Meklēt"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr "Rādīt/Slēpt meklēšanas joslu (Ctrl+F)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr "Meklēt:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr "Ievadiet meklējamo..."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr "Iepriekšējais rezultāts (Shift+Enter)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr "Nākamais rezultāts (Enter)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr "Notīrīt"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr "Atskaņošanas iespējas:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr "Atkārtot segmentu"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr "Nosaka vai atkārtot izvēlēto segmentu"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr "Sekot audio"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
@@ -623,75 +626,144 @@ msgstr ""
"Nosaka vai atskaņojot audio iezīmētajam segmentam vajadzētu automātiski "
"sekot tam kas tiek atskaņots."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr "Pāriet uz tekošo"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr "Pāriet uz šobrīd atskaņojamo tekstu"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr "1 no 100+ "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr "1 no "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr " "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr "Nekas nav atrasts"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr " no 100+"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr " no "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "API atslēgas kļūda"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Lūdzu ievadiet OpenAI API atslēgu iestatījumos"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Garuma maiņas iestatījumi"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Vēlamais teksta garums"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Apvienošanas iestatījumi"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Apvienot pēc attāluma"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Dalīt pie pieturzīmēm"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Dalīt pie maksimālā garuma"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Apvienot"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr "1/8 Apkopo transkripcijas"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr "2/8 Ielādē audio"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr "3/8 Ielādē identifikācijas modeli"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr "4/8 Apstrādā audio"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+msgid "5/8 Preparing transcripts"
+msgstr "5/8 Sagatavo transkripcijas"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr "6/8 Nosaka runātājus"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr "0/0 Kļūda nosakot runātājus"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr "7/8 Marķē runātāju teikumus"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr "8/8 Runātāju noteikšana pabeigta"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr "1. solis: Runātāju noteikšana"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr "Noteikt"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr "Gatavs noteikt runātājus"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr "Audio datne nav atrasta"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr "2. solis: Runātāju identifikācija"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr "Atskaņot paraugu"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr "Apvienot secīgus runātāja teikumus"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+msgid "Save"
+msgstr "Saglabāt"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Saglabāt failu"
diff --git a/buzz/locale/nl/LC_MESSAGES/buzz.po b/buzz/locale/nl/LC_MESSAGES/buzz.po
index 2e21acc2..75b59ea1 100644
--- a/buzz/locale/nl/LC_MESSAGES/buzz.po
+++ b/buzz/locale/nl/LC_MESSAGES/buzz.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2025-03-20 18:30+0100\n"
"Last-Translator: Heimen Stoffels \n"
"Language-Team: none\n"
@@ -552,138 +552,213 @@ msgstr "Bekijken"
msgid "Timestamps"
msgstr "Tijdstippen"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Exporteren"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Vertalen"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Grootte"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "Api-sleutel vereist"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Voer de OpenAI-api-sleutel in in de instellingen"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Grootteopties"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Voorkeurslengte van ondertiteling"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Samenvoegopties"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Samenvoegen op basis van tussenruimte"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Splitsen op basis van leestekens"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Splitsen op basis van max. lengte"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Samenvoegen"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Transcriptie wissen"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Bestand opslaan"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Bestand opslaan"
diff --git a/buzz/locale/pl_PL/LC_MESSAGES/buzz.po b/buzz/locale/pl_PL/LC_MESSAGES/buzz.po
index 2d452294..261fcd5b 100644
--- a/buzz/locale/pl_PL/LC_MESSAGES/buzz.po
+++ b/buzz/locale/pl_PL/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2024-03-17 20:50+0200\n"
"Last-Translator: \n"
"Language-Team: \n"
@@ -561,138 +561,213 @@ msgstr ""
msgid "Timestamps"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr ""
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Anuluj transkrypcję"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Zapisz plik"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
#, fuzzy
msgid "Save File"
diff --git a/buzz/locale/pt_BR/LC_MESSAGES/buzz.po b/buzz/locale/pt_BR/LC_MESSAGES/buzz.po
index 25165acd..39ae4c38 100644
--- a/buzz/locale/pt_BR/LC_MESSAGES/buzz.po
+++ b/buzz/locale/pt_BR/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Buzz\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2025-11-01 17:43-0300\n"
"Last-Translator: Paulo Schopf \n"
"Language-Team: none\n"
@@ -552,64 +552,68 @@ msgstr "Visualizar"
msgid "Timestamps"
msgstr "Marcações de tempo"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Exportar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Traduzir"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "Redimensionar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr "Procurar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr "Mostrar/Ocultar a Barra de Pesquisa (Ctrl+F)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr "Procurar:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr "Digite o texto a procurar..."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr "Encontro prévio (Shift+Enter)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr "Próximo encontro (Enter)"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr "Limpar"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr "Controles de Reprodução:"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr "Segmento de Loop"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr "Habilitar/desabilitar loop ao clicar em segmentos de transcrição"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr "Siga o Áudio"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
@@ -617,75 +621,146 @@ msgstr ""
"Ativar/desativar a opção de seguir a posição atual do áudio na transcrição. "
"Quando ativado, rola automaticamente para o texto atual."
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr "Rolar para o Atual"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr "Role até o texto falado no momento"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr "1 de 100+ encontros"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr "1 de "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr " encontros"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr "Nada encontrado"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr " de 100+ encontros"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr " de "
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "Chave API Necessária"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Insira a chave API OpenAI nas preferências"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr "Opções de Redimensionamento"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "Duração desejada da legenda"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "Opções de Mesclagem"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "Mesclar por intervalo"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "Dividir por pontuação"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "Dividir por tamanho máximo"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "Mesclar"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Iniciando transcrição..."
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Salvar Arquivo"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Salvar Arquivo"
diff --git a/buzz/locale/uk_UA/LC_MESSAGES/buzz.po b/buzz/locale/uk_UA/LC_MESSAGES/buzz.po
index f45a5184..2ef57f95 100644
--- a/buzz/locale/uk_UA/LC_MESSAGES/buzz.po
+++ b/buzz/locale/uk_UA/LC_MESSAGES/buzz.po
@@ -2,7 +2,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: \n"
"Last-Translator: Yevhen Popok \n"
"Language-Team: \n"
@@ -550,138 +550,213 @@ msgstr "Вигляд"
msgid "Timestamps"
msgstr "Позначки часу"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "Експорт"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "Перекласти"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "Потрібен API-ключ"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "Будь ласка, введіть API-ключ OpenAI в налаштуваннях"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr ""
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "Скасувати транскрипцію"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "Зберегти файл"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
msgid "Save File"
msgstr "Зберегти файл"
diff --git a/buzz/locale/zh_CN/LC_MESSAGES/buzz.po b/buzz/locale/zh_CN/LC_MESSAGES/buzz.po
index 0e9154a2..5cdb3a7b 100644
--- a/buzz/locale/zh_CN/LC_MESSAGES/buzz.po
+++ b/buzz/locale/zh_CN/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2023-05-01 15:45+0800\n"
"Last-Translator: \n"
"Language-Team: lamb \n"
@@ -562,139 +562,214 @@ msgstr "查看"
msgid "Timestamps"
msgstr "时间戳"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr "导出"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr "翻译"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr "调整大小"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr "需要API Key"
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr "请在偏好设置中输入OpenAI API Key"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
#, fuzzy
msgid "Resize Options"
msgstr "调整大小"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr "所需字幕长度"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr "合并选项"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr "按间隔合并"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr "按标点符号拆分"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr "按最大长度拆分"
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr "合并"
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "取消识别"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "保存文件"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
#, fuzzy
msgid "Save File"
diff --git a/buzz/locale/zh_TW/LC_MESSAGES/buzz.po b/buzz/locale/zh_TW/LC_MESSAGES/buzz.po
index ed67c2c8..fd0fe400 100644
--- a/buzz/locale/zh_TW/LC_MESSAGES/buzz.po
+++ b/buzz/locale/zh_TW/LC_MESSAGES/buzz.po
@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2025-11-28 16:49+0200\n"
+"POT-Creation-Date: 2025-11-23 12:55+0200\n"
"PO-Revision-Date: 2023-05-01 15:45+0800\n"
"Last-Translator: \n"
"Language-Team: Lamb\n"
@@ -557,138 +557,213 @@ msgstr ""
msgid "Timestamps"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:211
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:215
msgid "Export"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:230
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:234
msgid "Translate"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:240
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:177
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:244
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:175
msgid "Resize"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:252
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:257
+msgid "Identify Speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:269
msgid "Find"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:255
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:272
msgid "Show/Hide Search Bar (Ctrl+F)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:320
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:337
msgid "Find:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:326
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:343
msgid "Enter text to find..."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:339
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:356
msgid "Previous match (Shift+Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:347
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:364
msgid "Next match (Enter)"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:355
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:372
msgid "Clear"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:382
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:399
msgid "Playback Controls:"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:387
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:404
msgid "Loop Segment"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:389
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:406
msgid "Enable/disable looping when clicking on transcript segments"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:395
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:412
msgid "Follow Audio"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:397
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:414
msgid ""
"Enable/disable following the current audio position in the transcript. When "
"enabled, automatically scrolls to current text."
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:444
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:461
msgid "Scroll to Current"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:446
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:463
msgid "Scroll to the currently spoken text"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:768
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:785
msgid "1 of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
msgid "1 of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:770
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:787
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:775
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:792
msgid "No matches found"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:834
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:851
msgid " of 100+ matches"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:836
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:853
msgid " of "
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1191
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1208
msgid "API Key Required"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1192
+#: buzz/widgets/transcription_viewer/transcription_viewer_widget.py:1209
msgid "Please enter OpenAI API Key in preferences"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:159
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:157
msgid "Resize Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:170
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:168
msgid "Desired subtitle length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:195
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:193
msgid "Merge Options"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:206
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:204
msgid "Merge by gap"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:214
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:212
msgid "Split by punctuation"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:222
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:220
msgid "Split by max length"
msgstr ""
-#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:234
+#: buzz/widgets/transcription_viewer/transcription_resizer_widget.py:232
msgid "Merge"
msgstr ""
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:92
+msgid "1/8 Collecting transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:106
+msgid "2/8 Loading audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:115
+msgid "3/8 Loading alignment model"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:121
+msgid "4/8 Processing audio"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:133
+#, fuzzy
+msgid "5/8 Preparing transcripts"
+msgstr "取消錄製"
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:151
+msgid "6/8 Identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:160
+msgid "0/0 Error identifying speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:168
+msgid "7/8 Mapping speakers to transcripts"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:207
+msgid "8/8 Identification done"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:243
+msgid "Step 1: Identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:255
+msgid "Identify"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:265
+msgid "Ready to identify speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:267
+msgid "Audio file not found"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:283
+msgid "Step 2: Name speakers"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:298
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:391
+msgid "Play sample"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:313
+msgid "Merge speaker sentences"
+msgstr ""
+
+#: buzz/widgets/transcription_viewer/speaker_identification_widget.py:318
+#, fuzzy
+msgid "Save"
+msgstr "檔案"
+
#: buzz/widgets/transcription_viewer/export_transcription_menu.py:82
#, fuzzy
msgid "Save File"
diff --git a/buzz/transcriber/whisper_cpp.py b/buzz/transcriber/whisper_cpp.py
index 8f12aec8..7bea69fb 100644
--- a/buzz/transcriber/whisper_cpp.py
+++ b/buzz/transcriber/whisper_cpp.py
@@ -4,7 +4,6 @@ import sys
import logging
import subprocess
import json
-import tempfile
from typing import List
from buzz.assets import APP_BASE_DIR
from buzz.transcriber.transcriber import Segment, Task, FileTranscriptionTask
@@ -58,9 +57,7 @@ class WhisperCpp:
file_to_process = task.file_path
if file_ext not in supported_formats:
- # Create temporary WAV file
- temp_dir = tempfile.gettempdir()
- temp_file = os.path.join(temp_dir, f"buzz_temp_{os.path.basename(task.file_path)}.wav")
+ temp_file = task.file_path + ".wav"
logging.info(f"Converting {task.file_path} to WAV format")
diff --git a/buzz/widgets/icon.py b/buzz/widgets/icon.py
index cac92525..1efca875 100644
--- a/buzz/widgets/icon.py
+++ b/buzz/widgets/icon.py
@@ -82,6 +82,10 @@ class ResizeIcon(Icon):
def __init__(self, parent: QWidget):
super().__init__(get_path("assets/resize_black.svg"), parent)
+class SpeakerIdentificationIcon(Icon):
+ def __init__(self, parent: QWidget):
+ super().__init__(get_path("assets/speaker-identification.svg"), parent)
+
class VisibilityIcon(Icon):
def __init__(self, parent: QWidget):
super().__init__(
diff --git a/buzz/widgets/transcription_viewer/speaker_identification_widget.py b/buzz/widgets/transcription_viewer/speaker_identification_widget.py
new file mode 100644
index 00000000..cbbe6216
--- /dev/null
+++ b/buzz/widgets/transcription_viewer/speaker_identification_widget.py
@@ -0,0 +1,504 @@
+import re
+import os
+import logging
+import faster_whisper
+import torch
+import random
+from typing import Optional
+from PyQt6.QtMultimedia import QMediaPlayer, QAudioOutput
+from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QTimer
+from PyQt6.QtGui import QFont
+from PyQt6.QtWidgets import (
+ QWidget,
+ QFormLayout,
+ QVBoxLayout,
+ QHBoxLayout,
+ QLabel,
+ QProgressBar,
+ QPushButton,
+ QCheckBox,
+ QGroupBox,
+ QSpacerItem,
+ QSizePolicy,
+ QLayout,
+)
+from buzz.locale import _
+from buzz.db.entity.transcription import Transcription
+from buzz.db.service.transcription_service import TranscriptionService
+from buzz.paths import file_path_as_title
+from buzz.settings.settings import Settings
+from buzz.widgets.line_edit import LineEdit
+from buzz.transcriber.transcriber import Segment
+
+from ctc_forced_aligner.ctc_forced_aligner import (
+ generate_emissions,
+ get_alignments,
+ get_spans,
+ load_alignment_model,
+ postprocess_results,
+ preprocess_text,
+)
+from whisper_diarization.helpers import (
+ get_realigned_ws_mapping_with_punctuation,
+ get_sentences_speaker_mapping,
+ get_words_speaker_mapping,
+ langs_to_iso,
+ punct_model_langs,
+)
+from deepmultilingualpunctuation.deepmultilingualpunctuation import PunctuationModel
+from whisper_diarization.diarization import MSDDDiarizer
+
+SENTENCE_END = re.compile(r'.*[.!?。!?]')
+
+class IdentificationWorker(QObject):
+ finished = pyqtSignal(list)
+ progress_update = pyqtSignal(str)
+
+ def __init__(self, transcription, transcription_service):
+ super().__init__()
+ self.transcription = transcription
+ self.transcription_service = transcription_service
+
+ def get_transcript(self, audio, **kwargs) -> dict:
+ buzz_segments = self.transcription_service.get_transcription_segments(
+ transcription_id=self.transcription.id_as_uuid
+ )
+
+ segments = []
+ words = []
+ text = ""
+ for buzz_segment in buzz_segments:
+ words.append({
+ 'word': buzz_segment.text + " ",
+ 'start': buzz_segment.start_time / 100,
+ 'end': buzz_segment.end_time / 100,
+ })
+ text += buzz_segment.text + " "
+
+ if SENTENCE_END.match(buzz_segment.text):
+ segments.append({
+ 'text': text,
+ 'words': words
+ })
+ words = []
+ text = ""
+
+ return {
+ 'language': self.transcription.language,
+ 'segments': segments
+ }
+
+ def run(self):
+ self.progress_update.emit(_("1/8 Collecting transcripts"))
+
+ # Step 1 - Get transcript
+ # TODO - Add detected language to the transcript, detect and store separately in metadata
+ # Will also be relevant for template parsing of transcript file names
+ # - See diarize.py for example on how to get this info from whisper transcript, maybe other whisper models also have it
+ language = self.transcription.language if self.transcription.language else "en"
+
+ segments = self.transcription_service.get_transcription_segments(
+ transcription_id=self.transcription.id_as_uuid
+ )
+
+ full_transcript = "".join(segment.text for segment in segments)
+
+ self.progress_update.emit(_("2/8 Loading audio"))
+ audio_waveform = faster_whisper.decode_audio(self.transcription.file)
+
+ # Step 2 - Forced alignment
+ force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
+ use_cuda = torch.cuda.is_available() and force_cpu == "false"
+ device = "cuda" if use_cuda else "cpu"
+ torch_dtype = torch.float16 if use_cuda else torch.float32
+
+ self.progress_update.emit(_("3/8 Loading alignment model"))
+ alignment_model, alignment_tokenizer = load_alignment_model(
+ device,
+ dtype=torch_dtype,
+ )
+
+ self.progress_update.emit(_("4/8 Processing audio"))
+ emissions, stride = generate_emissions(
+ alignment_model,
+ torch.from_numpy(audio_waveform)
+ .to(alignment_model.dtype)
+ .to(alignment_model.device),
+ batch_size=8,
+ )
+
+ del alignment_model
+ torch.cuda.empty_cache()
+
+ self.progress_update.emit(_("5/8 Preparing transcripts"))
+ tokens_starred, text_starred = preprocess_text(
+ full_transcript,
+ romanize=True,
+ language=langs_to_iso[language],
+ )
+
+ segments, scores, blank_token = get_alignments(
+ emissions,
+ tokens_starred,
+ alignment_tokenizer,
+ )
+
+ spans = get_spans(tokens_starred, segments, blank_token)
+
+ word_timestamps = postprocess_results(text_starred, spans, stride, scores)
+
+ # Step 3 - Diarization
+ self.progress_update.emit(_("6/8 Identifying speakers"))
+
+ try:
+ diarizer_model = MSDDDiarizer(device)
+ speaker_ts = diarizer_model.diarize(torch.from_numpy(audio_waveform).unsqueeze(0))
+
+ except Exception as e:
+ self.progress_update.emit(_("0/0 Error identifying speakers"))
+ logging.error(f"Error during diarization: {e}")
+ return
+ finally:
+ del diarizer_model
+ torch.cuda.empty_cache()
+
+ # Step 4 - Reading timestamps <> Speaker Labels mapping
+ self.progress_update.emit(_("7/8 Mapping speakers to transcripts"))
+
+ wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")
+
+ if language in punct_model_langs:
+ # restoring punctuation in the transcript to help realign the sentences
+ punct_model = PunctuationModel(model="kredor/punctuate-all")
+
+ words_list = list(map(lambda x: x["word"], wsm))
+
+ labled_words = punct_model.predict(words_list, chunk_size=230)
+
+ ending_puncts = ".?!。!?"
+ model_puncts = ".,;:!?。!?"
+
+ # We don't want to punctuate U.S.A. with a period. Right?
+ is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
+
+ for word_dict, labeled_tuple in zip(wsm, labled_words):
+ word = word_dict["word"]
+ if (
+ word
+ and labeled_tuple[1] in ending_puncts
+ and (word[-1] not in model_puncts or is_acronym(word))
+ ):
+ word += labeled_tuple[1]
+ if word.endswith(".."):
+ word = word.rstrip(".")
+ word_dict["word"] = word
+
+ else:
+ logging.warning(
+ f"Punctuation restoration is not available for {language} language."
+ " Using the original punctuation."
+ )
+
+ wsm = get_realigned_ws_mapping_with_punctuation(wsm)
+ ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
+
+ self.progress_update.emit(_("8/8 Identification done"))
+ self.finished.emit(ssm)
+
+
+class SpeakerIdentificationWidget(QWidget):
+ resize_button_clicked = pyqtSignal()
+ transcription: Transcription
+ settings = Settings()
+
+ def __init__(
+ self,
+ transcription: Transcription,
+ transcription_service: TranscriptionService,
+ parent: Optional["QWidget"] = None,
+ flags: Qt.WindowType = Qt.WindowType.Widget,
+ transcriptions_updated_signal: Optional[pyqtSignal] = None,
+ ) -> None:
+ super().__init__(parent, flags)
+ self.transcription = transcription
+ self.transcription_service = transcription_service
+ self.transcriptions_updated_signal = transcriptions_updated_signal
+
+ self.identification_result = None
+
+ self.thread = None
+ self.worker = None
+
+ self.setMinimumWidth(650)
+ self.setMinimumHeight(400)
+
+ self.setWindowTitle(file_path_as_title(transcription.file))
+
+ layout = QFormLayout(self)
+ layout.setSizeConstraint(QLayout.SizeConstraint.SetMinAndMaxSize)
+
+ # Step 1: Identify speakers
+ step_1_label = QLabel(_("Step 1: Identify speakers"), self)
+ font = step_1_label.font()
+ font.setWeight(QFont.Weight.Bold)
+ step_1_label.setFont(font)
+ layout.addRow(step_1_label)
+
+ step_1_group_box = QGroupBox(self)
+ step_1_group_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
+ step_1_layout = QVBoxLayout(step_1_group_box)
+
+ self.step_1_row = QHBoxLayout()
+
+ self.step_1_button = QPushButton(_("Identify"))
+ self.step_1_button.setMinimumWidth(200)
+ self.step_1_button.clicked.connect(self.on_identify_button_clicked)
+
+ # Progress container with label and bar
+ progress_container = QVBoxLayout()
+
+ self.progress_label = QLabel(self)
+ if os.path.isfile(self.transcription.file):
+ self.progress_label.setText(_("Ready to identify speakers"))
+ else:
+ self.progress_label.setText(_("Audio file not found"))
+ self.step_1_button.setEnabled(False)
+
+ self.progress_bar = QProgressBar(self)
+ self.progress_bar.setMinimumWidth(400)
+ self.progress_bar.setRange(0, 8)
+ self.progress_bar.setValue(0)
+
+ progress_container.addWidget(self.progress_label)
+ progress_container.addWidget(self.progress_bar)
+
+ self.step_1_row.addLayout(progress_container)
+
+ self.step_1_row.addWidget(self.step_1_button, alignment=Qt.AlignmentFlag.AlignTop)
+
+ step_1_layout.addLayout(self.step_1_row)
+
+ layout.addRow(step_1_group_box)
+
+ # Spacer
+ spacer = QSpacerItem(0, 10, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed)
+ layout.addItem(spacer)
+
+ # Step 2: Name speakers
+ step_2_label = QLabel(_("Step 2: Name speakers"), self)
+ font = step_2_label.font()
+ font.setWeight(QFont.Weight.Bold)
+ step_2_label.setFont(font)
+ layout.addRow(step_2_label)
+
+ self.step_2_group_box = QGroupBox(self)
+ self.step_2_group_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
+ self.step_2_group_box.setEnabled(False)
+ step_2_layout = QVBoxLayout(self.step_2_group_box)
+
+ self.speaker_preview_row = QVBoxLayout()
+
+ self.speaker_0_input = LineEdit("Speaker 0", self)
+
+ self.speaker_0_preview_button = QPushButton(_("Play sample"))
+ self.speaker_0_preview_button.setMinimumWidth(200)
+ self.speaker_0_preview_button.clicked.connect(lambda: self.on_speaker_preview("Speaker 0"))
+
+ speaker_0_layout = QHBoxLayout()
+ speaker_0_layout.addWidget(self.speaker_0_input)
+ speaker_0_layout.addWidget(self.speaker_0_preview_button)
+
+ self.speaker_preview_row.addLayout(speaker_0_layout)
+
+ step_2_layout.addLayout(self.speaker_preview_row)
+
+ layout.addRow(self.step_2_group_box)
+
+ # Save button
+ self.merge_speaker_sentences = QCheckBox(_("Merge speaker sentences"))
+ self.merge_speaker_sentences.setChecked(True)
+ self.merge_speaker_sentences.setEnabled(False)
+ self.merge_speaker_sentences.setMinimumWidth(250)
+
+ self.save_button = QPushButton(_("Save"))
+ self.save_button.setEnabled(False)
+ self.save_button.clicked.connect(self.on_save_button_clicked)
+
+ layout.addRow(self.merge_speaker_sentences)
+ layout.addRow(self.save_button)
+
+ self.setLayout(layout)
+
+ # Invisible preview player
+ url = QUrl.fromLocalFile(self.transcription.file)
+ self.player = QMediaPlayer()
+ self.audio_output = QAudioOutput()
+ self.player.setAudioOutput(self.audio_output)
+ self.player.setSource(url)
+ self.player_timer = None
+
+ def on_identify_button_clicked(self):
+ self.step_1_button.setEnabled(False)
+
+ self.thread = QThread()
+ self.worker = IdentificationWorker(
+ self.transcription,
+ self.transcription_service
+ )
+ self.worker.moveToThread(self.thread)
+ self.thread.started.connect(self.worker.run)
+ self.worker.finished.connect(self.thread.quit)
+ self.worker.finished.connect(self.worker.deleteLater)
+ self.thread.finished.connect(self.thread.deleteLater)
+ self.worker.finished.connect(self.on_identification_finished)
+ self.worker.progress_update.connect(self.on_progress_update)
+
+ self.thread.start()
+
+ def on_progress_update(self, progress):
+ self.progress_label.setText(progress)
+
+ progress_value = 0
+ if progress and progress[0].isdigit():
+ progress_value = int(progress[0])
+ self.progress_bar.setValue(progress_value)
+ else:
+ logging.error(f"Invalid progress format: {progress}")
+
+ if progress_value == 8:
+ self.step_2_group_box.setEnabled(True)
+ self.merge_speaker_sentences.setEnabled(True)
+ self.save_button.setEnabled(True)
+
+ def on_identification_finished(self, result):
+ self.identification_result = result
+
+ unique_speakers = {entry['speaker'] for entry in result}
+
+ while self.speaker_preview_row.count():
+ item = self.speaker_preview_row.takeAt(0)
+ widget = item.widget()
+ if widget:
+ widget.deleteLater()
+ else:
+ layout = item.layout()
+ if layout:
+ while layout.count():
+ sub_item = layout.takeAt(0)
+ sub_widget = sub_item.widget()
+ if sub_widget:
+ sub_widget.deleteLater()
+
+ for speaker in sorted(unique_speakers):
+ speaker_input = LineEdit(speaker, self)
+ speaker_input.setMinimumWidth(200)
+
+ speaker_preview_button = QPushButton(_("Play sample"))
+ speaker_preview_button.setMinimumWidth(200)
+ speaker_preview_button.clicked.connect(lambda checked, s=speaker: self.on_speaker_preview(s))
+
+ speaker_layout = QHBoxLayout()
+ speaker_layout.addWidget(speaker_input)
+ speaker_layout.addWidget(speaker_preview_button)
+
+ self.speaker_preview_row.addLayout(speaker_layout)
+
+ def on_speaker_preview(self, speaker_id):
+ if self.player_timer:
+ self.player_timer.stop()
+
+ speaker_records = [record for record in self.identification_result if record['speaker'] == speaker_id]
+
+ if speaker_records:
+ random_record = random.choice(speaker_records)
+
+ start_time = random_record['start_time']
+ end_time = random_record['end_time']
+
+ self.player.setPosition(int(start_time))
+ self.player.play()
+
+ self.player_timer = QTimer(self)
+ self.player_timer.setSingleShot(True)
+ self.player_timer.timeout.connect(self.player.stop)
+ self.player_timer.start(min(end_time, 10 * 1000)) # 10 seconds
+
+ def on_save_button_clicked(self):
+ speaker_names = []
+ for i in range(self.speaker_preview_row.count()):
+ item = self.speaker_preview_row.itemAt(i)
+ if item.layout():
+ for j in range(item.layout().count()):
+ sub_item = item.layout().itemAt(j)
+ widget = sub_item.widget()
+ if isinstance(widget, LineEdit):
+ speaker_names.append(widget.text())
+
+ unique_speakers = {entry['speaker'] for entry in self.identification_result}
+ original_speakers = sorted(unique_speakers)
+ speaker_mapping = dict(zip(original_speakers, speaker_names))
+
+ segments = []
+ if self.merge_speaker_sentences.isChecked():
+ previous_segment = None
+
+ for entry in self.identification_result:
+ speaker_name = speaker_mapping.get(entry['speaker'], entry['speaker'])
+
+ if previous_segment and previous_segment['speaker'] == speaker_name:
+ previous_segment['end_time'] = entry['end_time']
+ previous_segment['text'] += " " + entry['text']
+ else:
+ if previous_segment:
+ segment = Segment(
+ start=previous_segment['start_time'],
+ end=previous_segment['end_time'],
+ text=f"{previous_segment['speaker']}: {previous_segment['text']}"
+ )
+ segments.append(segment)
+ previous_segment = {
+ 'start_time': entry['start_time'],
+ 'end_time': entry['end_time'],
+ 'speaker': speaker_name,
+ 'text': entry['text']
+ }
+
+ if previous_segment:
+ segment = Segment(
+ start=previous_segment['start_time'],
+ end=previous_segment['end_time'],
+ text=f"{previous_segment['speaker']}: {previous_segment['text']}"
+ )
+ segments.append(segment)
+ else:
+ for entry in self.identification_result:
+ speaker_name = speaker_mapping.get(entry['speaker'], entry['speaker'])
+ segment = Segment(
+ start=entry['start_time'],
+ end=entry['end_time'],
+ text=f"{speaker_name}: {entry['text']}"
+ )
+ segments.append(segment)
+
+ new_transcript_id = self.transcription_service.copy_transcription(
+ self.transcription.id_as_uuid
+ )
+
+ self.transcription_service.update_transcription_as_completed(new_transcript_id, segments)
+
+ # TODO - See if we can get rows in the transcription viewer to be of variable height
+ # If text is longer they should expand
+ if self.transcriptions_updated_signal:
+ self.transcriptions_updated_signal.emit(new_transcript_id)
+
+ self.player.stop()
+
+ if self.player_timer:
+ self.player_timer.stop()
+
+ self.close()
+
+ def closeEvent(self, event):
+ self.hide()
+
+ super().closeEvent(event)
diff --git a/buzz/widgets/transcription_viewer/transcription_resizer_widget.py b/buzz/widgets/transcription_viewer/transcription_resizer_widget.py
index a873eb0c..cb8dfcfc 100644
--- a/buzz/widgets/transcription_viewer/transcription_resizer_widget.py
+++ b/buzz/widgets/transcription_viewer/transcription_resizer_widget.py
@@ -37,8 +37,7 @@ from buzz.widgets.preferences_dialog.models.file_transcription_preferences impor
SENTENCE_END = re.compile(r'.*[.!?。!?]')
class TranscriptionWorker(QObject):
- finished = pyqtSignal()
- result_ready = pyqtSignal(list)
+ finished = pyqtSignal(list)
def __init__(self, transcription, transcription_options, transcription_service, regroup_string: str):
super().__init__()
@@ -85,7 +84,7 @@ class TranscriptionWorker(QObject):
if self.transcription_options.extract_speech and os.path.exists(speech_path):
transcription_file = str(speech_path)
transcription_file_exists = True
- # TODO - Fix VAD and Silence suppression that fails to work/download VAd model in compilded form on Mac and Windows
+    # TODO - Fix VAD and Silence suppression that fails to work/download VAD model in compiled form on Mac and Windows
try:
result = stable_whisper.transcribe_any(
@@ -113,8 +112,7 @@ class TranscriptionWorker(QObject):
)
)
- self.result_ready.emit(segments)
- self.finished.emit()
+ self.finished.emit(segments)
class TranscriptionResizerWidget(QWidget):
@@ -336,7 +334,7 @@ class TranscriptionResizerWidget(QWidget):
self.worker.finished.connect(self.thread.quit)
self.worker.finished.connect(self.worker.deleteLater)
self.thread.finished.connect(self.thread.deleteLater)
- self.worker.result_ready.connect(self.on_transcription_completed)
+ self.worker.finished.connect(self.on_transcription_completed)
self.thread.start()
diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
index e77c2179..51b4e67c 100644
--- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
+++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
@@ -1,5 +1,6 @@
import os
import logging
+import platform
from typing import Optional
from uuid import UUID
@@ -38,6 +39,7 @@ from buzz.widgets.icon import (
ResizeIcon,
ScrollToCurrentIcon,
VisibilityIcon,
+ SpeakerIdentificationIcon,
)
from buzz.translator import Translator
from buzz.widgets.text_display_box import TextDisplayBox
@@ -59,6 +61,10 @@ from buzz.widgets.transcription_viewer.transcription_view_mode_tool_button impor
)
from buzz.widgets.transcription_viewer.transcription_resizer_widget import TranscriptionResizerWidget
+# Underlying libs do not support Intel Macs
+if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
+ from buzz.widgets.transcription_viewer.speaker_identification_widget import SpeakerIdentificationWidget
+
class TranscriptionViewerWidget(QWidget):
resize_button_clicked = pyqtSignal()
@@ -85,6 +91,7 @@ class TranscriptionViewerWidget(QWidget):
self.setWindowTitle(file_path_as_title(transcription.file))
self.transcription_resizer_dialog = None
+ self.speaker_identification_dialog = None
self.transcriptions_updated_signal = transcriptions_updated_signal
self.translation_thread = None
@@ -98,7 +105,7 @@ class TranscriptionViewerWidget(QWidget):
# Loop functionality
self.segment_looping_enabled = self.settings.settings.value("transcription_viewer/segment_looping_enabled", False, type=bool)
-
+
# UI visibility preferences
self.playback_controls_visible = self.settings.settings.value("transcription_viewer/playback_controls_visible", False, type=bool)
self.find_widget_visible = self.settings.settings.value("transcription_viewer/find_widget_visible", False, type=bool)
@@ -165,18 +172,18 @@ class TranscriptionViewerWidget(QWidget):
# Create a better current segment display that handles long text
self.current_segment_frame = QFrame()
self.current_segment_frame.setFrameStyle(QFrame.Shape.NoFrame)
-
+
segment_layout = QVBoxLayout(self.current_segment_frame)
segment_layout.setContentsMargins(4, 4, 4, 4) # Minimal margins for clean appearance
segment_layout.setSpacing(0) # No spacing between elements
-
+
# Text display - centered with scroll capability (no header label)
self.current_segment_text = QLabel("")
self.current_segment_text.setAlignment(Qt.AlignmentFlag.AlignHCenter | Qt.AlignmentFlag.AlignTop)
self.current_segment_text.setWordWrap(True)
self.current_segment_text.setStyleSheet("color: #666; line-height: 1.2; margin: 0; padding: 4px;")
self.current_segment_text.setMinimumHeight(60) # Ensure minimum height for text
-
+
# Make it scrollable for long text
self.current_segment_scroll_area = QScrollArea()
self.current_segment_scroll_area.setWidget(self.current_segment_text)
@@ -185,13 +192,13 @@ class TranscriptionViewerWidget(QWidget):
self.current_segment_scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.current_segment_scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
self.current_segment_scroll_area.setStyleSheet("QScrollBar:vertical { width: 12px; } QScrollBar::handle:vertical { background: #ccc; border-radius: 6px; }")
-
+
# Ensure the text label can expand to show all content
self.current_segment_text.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred)
-
+
# Add scroll area to layout (simplified single-widget layout)
segment_layout.addWidget(self.current_segment_scroll_area)
-
+
# Initially hide the frame until there's content
self.current_segment_frame.hide()
@@ -247,6 +254,19 @@ class TranscriptionViewerWidget(QWidget):
toolbar.addWidget(resize_button)
+        # Underlying libs do not support Intel Macs
+ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
+ speaker_identification_button = QToolButton()
+ speaker_identification_button.setText(_("Identify Speakers"))
+ speaker_identification_button.setObjectName("speaker_identification_button")
+ speaker_identification_button.setIcon(SpeakerIdentificationIcon(self))
+ speaker_identification_button.setToolButtonStyle(
+ Qt.ToolButtonStyle.ToolButtonTextBesideIcon
+ )
+ speaker_identification_button.clicked.connect(self.on_speaker_identification_button_clicked)
+
+ toolbar.addWidget(speaker_identification_button)
+
# Add Find button
self.find_button = QToolButton()
self.find_button.setText(_("Find"))
@@ -267,14 +287,14 @@ class TranscriptionViewerWidget(QWidget):
# Table widget should take the majority of the space
layout.addWidget(self.table_widget, 1) # Stretch factor 1 (majority)
-
+
# Loop controls section (minimal space)
self.create_loop_controls()
layout.addWidget(self.loop_controls_frame, 0) # Stretch factor 0 (minimal)
-
+
# Audio player (minimal space)
layout.addWidget(self.audio_player, 0) # Stretch factor 0 (minimal)
-
+
# Text display box (minimal space)
layout.addWidget(self.text_display_box, 0) # Stretch factor 0 (minimal)
@@ -291,7 +311,7 @@ class TranscriptionViewerWidget(QWidget):
# Restore UI state from settings
self.restore_ui_state()
-
+
# Restore geometry from settings
self.load_geometry()
@@ -302,7 +322,7 @@ class TranscriptionViewerWidget(QWidget):
# Restore playback controls visibility
if self.playback_controls_visible:
self.show_loop_controls()
-
+
# Restore find widget visibility
if self.find_widget_visible:
self.show_search_bar()
@@ -312,28 +332,28 @@ class TranscriptionViewerWidget(QWidget):
self.search_frame = QFrame()
self.search_frame.setFrameStyle(QFrame.Shape.StyledPanel)
self.search_frame.setMaximumHeight(60)
-
+
search_layout = QHBoxLayout(self.search_frame)
search_layout.setContentsMargins(10, 5, 10, 5)
-
+
# Find label
search_label = QLabel(_("Find:"))
search_label.setStyleSheet("font-weight: bold;")
search_layout.addWidget(search_label)
-
+
# Find input - make it wider for better usability
self.search_input = QLineEdit()
self.search_input.setPlaceholderText(_("Enter text to find..."))
self.search_input.textChanged.connect(self.on_search_text_changed)
self.search_input.returnPressed.connect(self.search_next)
self.search_input.setMinimumWidth(300) # Increased from 200 to 300
-
+
# Add keyboard shortcuts for search navigation
from PyQt6.QtGui import QKeySequence
self.search_input.installEventFilter(self)
-
+
search_layout.addWidget(self.search_input)
-
+
# Search buttons - make them consistent height and remove hardcoded font sizes
self.search_prev_button = QPushButton("↑")
self.search_prev_button.setToolTip(_("Previous match (Shift+Enter)"))
@@ -342,7 +362,7 @@ class TranscriptionViewerWidget(QWidget):
self.search_prev_button.setMaximumWidth(40)
self.search_prev_button.setMinimumHeight(30) # Ensure consistent height
search_layout.addWidget(self.search_prev_button)
-
+
self.search_next_button = QPushButton("↓")
self.search_next_button.setToolTip(_("Next match (Enter)"))
self.search_next_button.clicked.connect(self.search_next)
@@ -350,21 +370,21 @@ class TranscriptionViewerWidget(QWidget):
self.search_next_button.setMaximumWidth(40)
self.search_next_button.setMinimumHeight(30) # Ensure consistent height
search_layout.addWidget(self.search_next_button)
-
+
# Clear button - make it bigger to accommodate different language translations
self.clear_search_button = QPushButton(_("Clear"))
self.clear_search_button.clicked.connect(self.clear_search)
self.clear_search_button.setMaximumWidth(80) # Increased from 60 to 80
self.clear_search_button.setMinimumHeight(30) # Ensure consistent height
search_layout.addWidget(self.clear_search_button)
-
+
# Results label
self.search_results_label = QLabel("")
self.search_results_label.setStyleSheet("color: #666;")
search_layout.addWidget(self.search_results_label)
-
+
search_layout.addStretch()
-
+
# Initially hide the search bar
self.search_frame.hide()
@@ -373,23 +393,23 @@ class TranscriptionViewerWidget(QWidget):
self.loop_controls_frame = QFrame()
self.loop_controls_frame.setFrameStyle(QFrame.Shape.StyledPanel)
self.loop_controls_frame.setMaximumHeight(50)
-
+
loop_layout = QHBoxLayout(self.loop_controls_frame)
loop_layout.setContentsMargins(10, 5, 10, 5)
loop_layout.setSpacing(8) # Add some spacing between elements for better visual separation
-
+
# Loop controls label
loop_label = QLabel(_("Playback Controls:"))
loop_label.setStyleSheet("font-weight: bold;")
loop_layout.addWidget(loop_label)
-
+
# Loop toggle button
self.loop_toggle = QCheckBox(_("Loop Segment"))
self.loop_toggle.setChecked(self.segment_looping_enabled)
self.loop_toggle.setToolTip(_("Enable/disable looping when clicking on transcript segments"))
self.loop_toggle.toggled.connect(self.on_loop_toggle_changed)
loop_layout.addWidget(self.loop_toggle)
-
+
# Follow audio toggle button
self.follow_audio_enabled = self.settings.settings.value("transcription_viewer/follow_audio_enabled", False, type=bool)
self.follow_audio_toggle = QCheckBox(_("Follow Audio"))
@@ -397,19 +417,19 @@ class TranscriptionViewerWidget(QWidget):
self.follow_audio_toggle.setToolTip(_("Enable/disable following the current audio position in the transcript. When enabled, automatically scrolls to current text."))
self.follow_audio_toggle.toggled.connect(self.on_follow_audio_toggle_changed)
loop_layout.addWidget(self.follow_audio_toggle)
-
+
# Visual separator
separator1 = QFrame()
separator1.setFrameShape(QFrame.Shape.VLine)
separator1.setFrameShadow(QFrame.Shadow.Sunken)
separator1.setMaximumHeight(20)
loop_layout.addWidget(separator1)
-
+
# Speed controls
speed_label = QLabel("Speed:")
speed_label.setStyleSheet("font-weight: bold;")
loop_layout.addWidget(speed_label)
-
+
self.speed_combo = QComboBox()
self.speed_combo.setEditable(True)
self.speed_combo.addItems(["0.5x", "0.75x", "1x", "1.25x", "1.5x", "2x"])
@@ -417,29 +437,29 @@ class TranscriptionViewerWidget(QWidget):
self.speed_combo.currentTextChanged.connect(self.on_speed_changed)
self.speed_combo.setMaximumWidth(80)
loop_layout.addWidget(self.speed_combo)
-
+
self.speed_down_btn = QPushButton("-")
self.speed_down_btn.setMaximumWidth(40) # Match search button width
self.speed_down_btn.setMinimumHeight(30) # Match search button height
self.speed_down_btn.clicked.connect(self.decrease_speed)
loop_layout.addWidget(self.speed_down_btn)
-
+
self.speed_up_btn = QPushButton("+")
self.speed_up_btn.setMaximumWidth(40) # Match speed down button width
self.speed_up_btn.setMinimumHeight(30) # Match search button height
self.speed_up_btn.clicked.connect(self.increase_speed)
loop_layout.addWidget(self.speed_up_btn)
-
+
# Initialize speed control with current value from audio player
self.initialize_speed_control()
-
+
# Visual separator
separator2 = QFrame()
separator2.setFrameShape(QFrame.Shape.VLine)
separator2.setFrameShadow(QFrame.Shadow.Sunken)
separator2.setMaximumHeight(20)
loop_layout.addWidget(separator2)
-
+
# Scroll to current button
self.scroll_to_current_button = QPushButton(_("Scroll to Current"))
self.scroll_to_current_button.setIcon(ScrollToCurrentIcon(self))
@@ -448,16 +468,16 @@ class TranscriptionViewerWidget(QWidget):
self.scroll_to_current_button.setMinimumHeight(30)
self.scroll_to_current_button.setStyleSheet("QPushButton { padding: 4px 8px; }") # Better padding
loop_layout.addWidget(self.scroll_to_current_button)
-
+
loop_layout.addStretch()
-
+
# Initially hide the loop controls frame
self.loop_controls_frame.hide()
def show_loop_controls(self):
"""Show the loop controls when audio is playing"""
self.loop_controls_frame.show()
-
+
# Save the visibility state to settings
self.playback_controls_visible = True
self.settings.settings.setValue("transcription_viewer/playback_controls_visible", self.playback_controls_visible)
@@ -465,7 +485,7 @@ class TranscriptionViewerWidget(QWidget):
def hide_loop_controls(self):
"""Hide the loop controls when audio is not playing"""
self.loop_controls_frame.hide()
-
+
# Save the visibility state to settings
self.playback_controls_visible = False
self.settings.settings.setValue("transcription_viewer/playback_controls_visible", self.playback_controls_visible)
@@ -600,7 +620,7 @@ class TranscriptionViewerWidget(QWidget):
def on_audio_playback_state_changed(self, state):
"""Handle audio playback state changes to automatically show/hide playback controls"""
from PyQt6.QtMultimedia import QMediaPlayer
-
+
if state == QMediaPlayer.PlaybackState.PlayingState:
# Show playback controls when audio starts playing
if self.view_mode == ViewMode.TIMESTAMPS:
@@ -630,25 +650,25 @@ class TranscriptionViewerWidget(QWidget):
# Extract the numeric value from speed text (e.g., "1.5x" -> 1.5)
clean_text = speed_text.replace('x', '').strip()
speed_value = float(clean_text)
-
+
# Clamp the speed value to valid range
speed_value = max(0.1, min(5.0, speed_value))
-
+
# Update the combo box text to show the clamped value
if not speed_text.endswith('x'):
speed_text = f"{speed_value:.2f}x"
-
+
# Block signals to prevent recursion
self.speed_combo.blockSignals(True)
self.speed_combo.setCurrentText(speed_text)
self.speed_combo.blockSignals(False)
-
+
# Set the playback rate on the audio player
self.audio_player.media_player.setPlaybackRate(speed_value)
-
+
# Save the new rate to settings
self.settings.set_value(self.settings.Key.AUDIO_PLAYBACK_RATE, speed_value)
-
+
except ValueError:
logging.warning(f"Invalid speed value: {speed_text}")
# Reset to current valid value
@@ -680,14 +700,14 @@ class TranscriptionViewerWidget(QWidget):
"""Set the playback speed programmatically"""
# Clamp the speed value to valid range
speed = max(0.1, min(5.0, speed))
-
+
# Update the combo box
speed_text = f"{speed:.2f}x"
self.speed_combo.setCurrentText(speed_text)
-
+
# Set the playback rate on the audio player
self.audio_player.media_player.setPlaybackRate(speed)
-
+
# Save the new rate to settings
self.settings.set_value(self.settings.Key.AUDIO_PLAYBACK_RATE, speed)
@@ -707,49 +727,49 @@ class TranscriptionViewerWidget(QWidget):
"""Perform the actual search based on current view mode"""
self.search_results = []
self.current_search_index = 0
-
+
if self.view_mode == ViewMode.TIMESTAMPS:
self.search_in_table()
else: # TEXT or TRANSLATION mode
self.search_in_text()
-
+
self.update_search_ui()
def search_in_table(self):
"""Search in the table view (segments)"""
segments = self.table_widget.segments()
search_text_lower = self.search_text.lower()
-
+
# Limit search results to avoid performance issues with very long segments
max_results = 100
-
+
for i, segment in enumerate(segments):
if len(self.search_results) >= max_results:
break
-
+
text = segment.value("text").lower()
if search_text_lower in text:
self.search_results.append(("table", i, segment))
-
+
# Also search in translations if available
if self.has_translations:
for i, segment in enumerate(segments):
if len(self.search_results) >= max_results:
break
-
+
translation = segment.value("translation").lower()
if search_text_lower in translation:
- self.search_results.append(("table", i, segment))
+ self.search_results.append(("table", i, segment))
def search_in_text(self):
"""Search in the text display box"""
text = self.text_display_box.toPlainText()
search_text_lower = self.search_text.lower()
text_lower = text.lower()
-
+
# Limit search results to avoid performance issues with very long text
max_results = 100
-
+
start = 0
result_count = 0
while True:
@@ -780,9 +800,9 @@ class TranscriptionViewerWidget(QWidget):
"""Highlight the current search match"""
if not self.search_results:
return
-
+
match_type, match_data, _ = self.search_results[self.current_search_index]
-
+
if match_type == "table":
# Highlight in table
self.highlight_table_match(match_data)
@@ -802,10 +822,10 @@ class TranscriptionViewerWidget(QWidget):
cursor = QTextCursor(self.text_display_box.document())
cursor.setPosition(start_pos)
cursor.setPosition(start_pos + len(self.search_text), QTextCursor.MoveMode.KeepAnchor)
-
+
# Set the cursor to highlight the text
self.text_display_box.setTextCursor(cursor)
-
+
# Ensure the highlighted text is visible
self.text_display_box.ensureCursorVisible()
@@ -813,7 +833,7 @@ class TranscriptionViewerWidget(QWidget):
"""Go to next search result"""
if not self.search_results:
return
-
+
self.current_search_index = (self.current_search_index + 1) % len(self.search_results)
self.highlight_current_match()
self.update_search_results_label()
@@ -822,7 +842,7 @@ class TranscriptionViewerWidget(QWidget):
"""Go to previous search result"""
if not self.search_results:
return
-
+
self.current_search_index = (self.current_search_index - 1) % len(self.search_results)
self.highlight_current_match()
self.update_search_results_label()
@@ -845,13 +865,13 @@ class TranscriptionViewerWidget(QWidget):
self.search_prev_button.setEnabled(False)
self.search_next_button.setEnabled(False)
-
+
# Clear text highlighting
if self.view_mode in (ViewMode.TEXT, ViewMode.TRANSLATION):
cursor = QTextCursor(self.text_display_box.document())
cursor.clearSelection()
self.text_display_box.setTextCursor(cursor)
-
+
# Keep search bar visible but clear the input
self.search_input.setFocus()
@@ -861,7 +881,7 @@ class TranscriptionViewerWidget(QWidget):
self.find_button.setChecked(False) # Sync button state
self.clear_search()
self.search_input.clearFocus()
-
+
# Save the visibility state to settings
self.find_widget_visible = False
self.settings.settings.setValue("transcription_viewer/find_widget_visible", False)
@@ -869,11 +889,11 @@ class TranscriptionViewerWidget(QWidget):
def setup_shortcuts(self):
"""Set up keyboard shortcuts"""
from PyQt6.QtGui import QShortcut, QKeySequence
-
+
# Search shortcut (Ctrl+F)
search_shortcut = QShortcut(QKeySequence(self.shortcuts.get(Shortcut.SEARCH_TRANSCRIPT)), self)
search_shortcut.activated.connect(self.focus_search_input)
-
+
# Scroll to current text shortcut (Ctrl+G)
scroll_to_current_shortcut = QShortcut(QKeySequence(self.shortcuts.get(Shortcut.SCROLL_TO_CURRENT_TEXT)), self)
scroll_to_current_shortcut.activated.connect(self.on_scroll_to_current_button_clicked)
@@ -912,7 +932,7 @@ class TranscriptionViewerWidget(QWidget):
self.find_button.setChecked(True) # Sync button state
self.search_input.setFocus()
self.search_input.selectAll()
-
+
# Save the visibility state to settings
self.find_widget_visible = True
self.settings.settings.setValue("transcription_viewer/find_widget_visible", True)
@@ -923,7 +943,7 @@ class TranscriptionViewerWidget(QWidget):
self.hide_search_bar()
else:
self.show_search_bar()
-
+
# Save the visibility state to settings
self.find_widget_visible = self.search_frame.isVisible()
self.settings.settings.setValue("transcription_viewer/find_widget_visible", self.find_widget_visible)
@@ -934,7 +954,7 @@ class TranscriptionViewerWidget(QWidget):
self.find_button.setChecked(True)
self.search_input.setFocus()
self.search_input.selectAll()
-
+
# Save the visibility state to settings
self.find_widget_visible = True
self.settings.settings.setValue("transcription_viewer/find_widget_visible", True)
@@ -942,7 +962,7 @@ class TranscriptionViewerWidget(QWidget):
def eventFilter(self, obj, event):
"""Event filter to handle keyboard shortcuts in search input"""
from PyQt6.QtCore import QEvent, Qt
-
+
if obj == self.search_input and event.type() == QEvent.Type.KeyPress:
# The event is already a QKeyEvent, no need to create a new one
if event.key() == Qt.Key.Key_Return and event.modifiers() == Qt.KeyboardModifier.ShiftModifier:
@@ -999,7 +1019,7 @@ class TranscriptionViewerWidget(QWidget):
self.loop_controls_frame.hide()
# Hide current segment display in translation mode
self.current_segment_frame.hide()
-
+
# Refresh search if there's active search text
if self.search_text:
self.perform_search()
@@ -1007,7 +1027,7 @@ class TranscriptionViewerWidget(QWidget):
def on_view_mode_changed(self, view_mode: ViewMode) -> None:
self.view_mode = view_mode
self.reset_view()
-
+
# Refresh search if there's active search text
if self.search_text:
self.perform_search()
@@ -1091,17 +1111,17 @@ class TranscriptionViewerWidget(QWidget):
if current_segment is not None:
self.current_segment_text.setText(current_segment.value("text"))
self.current_segment_frame.show() # Show the frame when there's a current segment
-
+
# Force the text label to recalculate its size
self.current_segment_text.adjustSize()
-
+
# Resize the frame to fit the text content
self.resize_current_segment_frame()
-
+
# Ensure the scroll area updates properly and shows scrollbars when needed
self.current_segment_scroll_area.updateGeometry()
self.current_segment_scroll_area.verticalScrollBar().setVisible(True) # Ensure scrollbar is visible
-
+
# Update highlighting based on follow audio and loop settings
if self.follow_audio_enabled:
# Follow audio mode: highlight the current segment based on audio position
@@ -1143,30 +1163,30 @@ class TranscriptionViewerWidget(QWidget):
# Calculate the height needed for the text area
line_height = self.current_segment_text.fontMetrics().lineSpacing()
max_visible_lines = 3 # Fixed at 3 lines for consistency and clean UI
-
+
# Calculate the height needed for the maximum visible lines (25% larger)
text_height = line_height * max_visible_lines * 1.25
-
+
# Add some vertical margins/padding
margins = 8 # Increased from 2 to 8 for better spacing
-
+
# Calculate total height needed (no header height anymore)
total_height = text_height + margins
-
+
# Convert to integer since Qt methods expect int values
total_height = int(total_height)
-
+
# Set maximum height to ensure consistent sizing, but allow minimum to be flexible
self.current_segment_frame.setMaximumHeight(total_height)
self.current_segment_frame.setMinimumHeight(total_height)
-
+
# Convert text_height to integer since Qt methods expect int values
text_height = int(text_height)
-
+
# Allow the scroll area to be flexible in height for proper scrolling
self.current_segment_scroll_area.setMinimumHeight(text_height)
self.current_segment_scroll_area.setMaximumHeight(text_height)
-
+
# Allow the text label to size naturally for proper scrolling
self.current_segment_text.setMinimumHeight(text_height)
@@ -1220,12 +1240,27 @@ class TranscriptionViewerWidget(QWidget):
self.transcription_resizer_dialog.show()
+ def on_speaker_identification_button_clicked(self):
+        # Underlying libs do not support Intel Macs
+ if not (platform.system() == "Darwin" and platform.machine() == "x86_64"):
+ self.speaker_identification_dialog = SpeakerIdentificationWidget(
+ transcription=self.transcription,
+ transcription_service=self.transcription_service,
+ transcriptions_updated_signal=self.transcriptions_updated_signal,
+ )
+
+ self.transcriptions_updated_signal.connect(self.close)
+
+ self.speaker_identification_dialog.show()
+
+ pass
+
def on_loop_toggle_changed(self, enabled: bool):
"""Handle loop toggle state change"""
self.segment_looping_enabled = enabled
# Save preference to settings
self.settings.settings.setValue("transcription_viewer/segment_looping_enabled", enabled)
-
+
if enabled:
# If looping is re-enabled and we have a selected segment, return to it
if self.currently_selected_segment is not None:
@@ -1235,21 +1270,21 @@ class TranscriptionViewerWidget(QWidget):
if segment.value("id") == self.currently_selected_segment.value("id"):
# Highlight and scroll to the selected segment
self.table_widget.highlight_and_scroll_to_row(i)
-
+
# Get the segment timing
start_time = self.currently_selected_segment.value("start_time")
end_time = self.currently_selected_segment.value("end_time")
-
+
# Set the loop range for the selected segment
self.audio_player.set_range((start_time, end_time))
-
+
# If audio is currently playing and outside the range, jump to the start
current_pos = self.audio_player.position_ms
playback_state = self.audio_player.media_player.playbackState()
- if (playback_state == QMediaPlayer.PlaybackState.PlayingState and
+ if (playback_state == QMediaPlayer.PlaybackState.PlayingState and
(current_pos < start_time or current_pos > end_time)):
self.audio_player.set_position(start_time)
-
+
break
else:
# Clear any existing range if looping is disabled
@@ -1260,7 +1295,7 @@ class TranscriptionViewerWidget(QWidget):
self.follow_audio_enabled = enabled
# Save preference to settings
self.settings.settings.setValue("transcription_viewer/follow_audio_enabled", enabled)
-
+
if enabled:
# When follow audio is first enabled, automatically scroll to current position
# This gives immediate feedback that the feature is working
@@ -1310,17 +1345,17 @@ class TranscriptionViewerWidget(QWidget):
# Only scroll if we're in timestamps view mode (table is visible)
if self.view_mode != ViewMode.TIMESTAMPS:
return
-
+
current_pos = self.audio_player.position_ms
segments = self.table_widget.segments()
-
+
# Find the current segment based on audio position
current_segment = next(
- (segment for segment in segments
+ (segment for segment in segments
if segment.value("start_time") <= current_pos < segment.value("end_time")),
None
)
-
+
if current_segment is not None:
# Find the row index and scroll to it
for i, segment in enumerate(segments):
@@ -1329,7 +1364,7 @@ class TranscriptionViewerWidget(QWidget):
# Method 1: Use the table widget's built-in scrolling method
self.table_widget.highlight_and_scroll_to_row(i)
break
-
+
except Exception as e:
pass # Silently handle any errors
@@ -1346,6 +1381,9 @@ class TranscriptionViewerWidget(QWidget):
if self.transcription_resizer_dialog:
self.transcription_resizer_dialog.close()
+ if self.speaker_identification_dialog:
+ self.speaker_identification_dialog.close()
+
self.translator.stop()
self.translation_thread.quit()
diff --git a/ctc_forced_aligner b/ctc_forced_aligner
new file mode 160000
index 00000000..1f0a5f86
--- /dev/null
+++ b/ctc_forced_aligner
@@ -0,0 +1 @@
+Subproject commit 1f0a5f860d3d9daf3d94edb1c7d18f90d1702e5b
diff --git a/deepmultilingualpunctuation b/deepmultilingualpunctuation
new file mode 160000
index 00000000..5a0dd7f4
--- /dev/null
+++ b/deepmultilingualpunctuation
@@ -0,0 +1 @@
+Subproject commit 5a0dd7f4fd56687f59405aa8eba1144393d8b74b
diff --git a/demucs/.github/ISSUE_TEMPLATE/bug.md b/demucs/.github/ISSUE_TEMPLATE/bug.md
new file mode 100644
index 00000000..217654a9
--- /dev/null
+++ b/demucs/.github/ISSUE_TEMPLATE/bug.md
@@ -0,0 +1,33 @@
+---
+name: 🐛 Bug Report
+about: Submit a bug report to help us improve
+labels: 'bug'
+---
+
+## 🐛 Bug Report
+
+(A clear and concise description of what the bug is)
+
+## To Reproduce
+
+(Write your steps here:)
+
+1. Step 1...
+1. Step 2...
+1. Step 3...
+
+## Expected behavior
+
+(Write what you thought would happen.)
+
+## Actual Behavior
+
+(Write what happened. Add screenshots, if applicable.)
+
+## Your Environment
+
+
+
+- Python and PyTorch version:
+- Operating system and version (desktop or mobile):
+- Hardware (gpu or cpu, amount of RAM etc.):
diff --git a/demucs/.github/ISSUE_TEMPLATE/question.md b/demucs/.github/ISSUE_TEMPLATE/question.md
new file mode 100644
index 00000000..85a007e4
--- /dev/null
+++ b/demucs/.github/ISSUE_TEMPLATE/question.md
@@ -0,0 +1,10 @@
+---
+name: "❓Questions/Help/Support"
+about: If you have a question about the paper, code or algorithm, please ask here!
+labels: question
+
+---
+
+## ❓ Questions
+
+(Please ask your question here.)
diff --git a/demucs/.github/workflows/linter.yml b/demucs/.github/workflows/linter.yml
new file mode 100644
index 00000000..64f235fb
--- /dev/null
+++ b/demucs/.github/workflows/linter.yml
@@ -0,0 +1,36 @@
+name: linter
+on:
+ push:
+ branches: [ main ]
+ pull_request:
+ branches: [ main ]
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
+
+ - uses: actions/cache@v2
+ with:
+ path: env
+ key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}
+
+ - name: Install dependencies
+ run: |
+ python3 -m venv env
+ . env/bin/activate
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install '.[dev]'
+
+
+ - name: Run linter
+ run: |
+ . env/bin/activate
+ make linter
diff --git a/demucs/.github/workflows/tests.yml b/demucs/.github/workflows/tests.yml
new file mode 100644
index 00000000..b31e3dd6
--- /dev/null
+++ b/demucs/.github/workflows/tests.yml
@@ -0,0 +1,36 @@
+name: tests
+on:
+ push:
+ branches: [ main ]
+ pull_request:
+ branches: [ main ]
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
+
+ - uses: actions/cache@v2
+ with:
+ path: env
+ key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}
+
+ - name: Install dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y ffmpeg
+ python3 -m venv env
+ . env/bin/activate
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+
+ - name: Run separation test
+ run: |
+ . env/bin/activate
+ make test_eval
diff --git a/demucs/.gitignore b/demucs/.gitignore
new file mode 100644
index 00000000..179cf0dd
--- /dev/null
+++ b/demucs/.gitignore
@@ -0,0 +1,17 @@
+*.egg-info
+__pycache__
+Session.vim
+/build
+/dist
+/lab
+/metadata
+/notebooks
+/outputs
+/release
+/release_models
+/separated
+/tests
+/trash
+/misc
+/mdx
+.mypy_cache
diff --git a/demucs/CODE_OF_CONDUCT.md b/demucs/CODE_OF_CONDUCT.md
new file mode 100644
index 00000000..f049d4c5
--- /dev/null
+++ b/demucs/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@fb.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/demucs/CONTRIBUTING.md b/demucs/CONTRIBUTING.md
new file mode 100644
index 00000000..f14f4af3
--- /dev/null
+++ b/demucs/CONTRIBUTING.md
@@ -0,0 +1,23 @@
+# Contributing to Demucs
+
+## Pull Requests
+
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+Demucs is the implementation of a research paper.
+Therefore, we do not plan on accepting many pull requests for new features.
+We certainly welcome them for bug fixes.
+
+
+## Issues
+
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+
+## License
+By contributing to this repository, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/demucs/Demucs.ipynb b/demucs/Demucs.ipynb
new file mode 100644
index 00000000..9ebcfd5a
--- /dev/null
+++ b/demucs/Demucs.ipynb
@@ -0,0 +1,153 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Be9yoh-ILfRr"
+ },
+ "source": [
+ "# Hybrid Demucs\n",
+ "\n",
+ "Feel free to use the Colab version:\n",
+ "https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 139
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 12277,
+ "status": "ok",
+ "timestamp": 1583778134659,
+ "user": {
+ "displayName": "Marllus Lustosa",
+ "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgLl2RbW64ZyWz3Y8IBku0zhHCMnt7fz7fEl0LTdA=s64",
+ "userId": "14811735256675200480"
+ },
+ "user_tz": 180
+ },
+ "id": "kOjIPLlzhPfn",
+ "outputId": "c75f17ec-b576-4105-bc5b-c2ac9c1018a3"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -U demucs\n",
+ "# or for local development, if you have a clone of Demucs\n",
+ "# pip install -e ."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "5lYOzKKCKAbJ"
+ },
+ "outputs": [],
+ "source": [
+ "# You can use the `demucs` command line to separate tracks\n",
+ "!demucs test.mp3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also load directly the pretrained models,\n",
+ "# for instance for the MDX 2021 winning model of Track A:\n",
+ "from demucs import pretrained\n",
+ "model = pretrained.get_model('mdx')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Because `model` is a bag of 4 models, you cannot directly call it on your data,\n",
+ "# but the `apply_model` will know what to do of it.\n",
+ "import torch\n",
+ "from demucs.apply import apply_model\n",
+ "x = torch.randn(1, 2, 44100 * 10) # ten seconds of white noise for the demo\n",
+ "out = apply_model(model, x)[0] # shape is [S, C, T] with S the number of sources\n",
+ "\n",
+ "# So let see, where is all the white noise content is going ?\n",
+ "for name, source in zip(model.sources, out):\n",
+ " print(name, source.std() / x.std())\n",
+ "# The outputs are quite weird to be fair, not what I would have expected."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# now let's take a single model from the bag, and let's test it on a pure cosine\n",
+ "freq = 440 # in Hz\n",
+ "sr = model.samplerate\n",
+ "t = torch.arange(10 * sr).float() / sr\n",
+ "x = torch.cos(2 * 3.1416 * freq * t).expand(1, 2, -1)\n",
+ "sub_model = model.models[3]\n",
+ "out = sub_model(x)[0]\n",
+ "\n",
+ "# Same question where does it go?\n",
+ "for name, source in zip(model.sources, out):\n",
+ " print(name, source.std() / x.std())\n",
+ " \n",
+ "# Well now it makes much more sense, all the energy is going\n",
+ "# in the `other` source.\n",
+ "# Feel free to try lower pitch (try 80 Hz) to see what happens !"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For training or more fun, refer to the Demucs README on our repo\n",
+ "# https://github.com/facebookresearch/demucs/tree/main/demucs"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "authorship_tag": "ABX9TyM9xpVr1M86NRcjtQ7g9tCx",
+ "collapsed_sections": [],
+ "name": "Demucs.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/demucs/LICENSE b/demucs/LICENSE
new file mode 100644
index 00000000..a45a376f
--- /dev/null
+++ b/demucs/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/demucs/MANIFEST.in b/demucs/MANIFEST.in
new file mode 100644
index 00000000..96e5f54f
--- /dev/null
+++ b/demucs/MANIFEST.in
@@ -0,0 +1,13 @@
+recursive-exclude env *
+recursive-include conf *.yaml
+include Makefile
+include LICENSE
+include demucs.png
+include outputs.tar.gz
+include test.mp3
+include requirements.txt
+include requirements_minimal.txt
+include mypy.ini
+include demucs/py.typed
+include demucs/remote/*.txt
+include demucs/remote/*.yaml
diff --git a/demucs/Makefile b/demucs/Makefile
new file mode 100644
index 00000000..0474d587
--- /dev/null
+++ b/demucs/Makefile
@@ -0,0 +1,36 @@
+all: linter tests
+
+linter:
+ flake8 demucs
+ mypy demucs
+
+tests: test_train test_eval
+
+test_train: tests/musdb
+ _DORA_TEST_PATH=/tmp/demucs python3 -m dora run --clear \
+ dset.musdb=./tests/musdb dset.segment=4 dset.shift=2 epochs=2 model=demucs \
+ demucs.depth=2 demucs.channels=4 test.sdr=false misc.num_workers=0 test.workers=0 \
+ test.shifts=0
+
+test_eval:
+ python3 -m demucs -n demucs_unittest test.mp3
+ python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3
+ python3 -m demucs -n demucs_unittest --mp3 test.mp3
+ python3 -m demucs -n demucs_unittest --flac --int24 test.mp3
+ python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3
+ python3 -m demucs -n demucs_unittest --segment 8 test.mp3
+ python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3
+ python3 -m demucs --list-models
+
+tests/musdb:
+ test -e tests || mkdir tests
+ python3 -c 'import musdb; musdb.DB("tests/tmp", download=True)'
+ musdbconvert tests/tmp tests/musdb
+
+dist:
+ python3 setup.py sdist
+
+clean:
+ rm -r dist build *.egg-info
+
+.PHONY: linter dist test_train test_eval
diff --git a/demucs/README.md b/demucs/README.md
new file mode 100644
index 00000000..1bc16ee6
--- /dev/null
+++ b/demucs/README.md
@@ -0,0 +1,319 @@
+# Demucs Music Source Separation
+
+
+
+
+
+**This is the officially maintained Demucs** now that I (Alexandre Défossez) have left Meta to join [Kyutai](https://twitter.com/kyutai_labs).
+Note that I'm not actively working on Demucs anymore, so expect slow replies and no new feature for now.
+
+
+
+This is the 4th release of Demucs (v4), featuring Hybrid Transformer based source separation.
+**For the classic Hybrid Demucs (v3):** [Go this commit][demucs_v3].
+If you are experiencing issues and want the old Demucs back, please file an issue, and then you can get back to Demucs v3 with
+`git checkout v3`. You can also go [Demucs v2][demucs_v2].
+
+
+Demucs is a state-of-the-art music source separation model, currently capable of separating
+drums, bass, and vocals from the rest of the accompaniment.
+Demucs is based on a U-Net convolutional architecture inspired by [Wave-U-Net][waveunet].
+The v4 version features [Hybrid Transformer Demucs][htdemucs], a hybrid spectrogram/waveform separation model using Transformers.
+It is based on [Hybrid Demucs][hybrid_paper] (also provided in this repo), with the innermost layers
+replaced by a cross-domain Transformer Encoder. This Transformer uses self-attention within each domain,
+and cross-attention across domains.
+The model achieves a SDR of 9.00 dB on the MUSDB HQ test set. Moreover, when using sparse attention
+kernels to extend its receptive field and per source fine-tuning, we achieve state-of-the-art 9.20 dB of SDR.
+
+Samples are available [on our sample page](https://ai.honu.io/papers/htdemucs/index.html).
+Checkout [our paper][htdemucs] for more information.
+It has been trained on the [MUSDB HQ][musdb] dataset + an extra training dataset of 800 songs.
+This model separates drums, bass and vocals and other stems for any song.
+
+
+As Hybrid Transformer Demucs is brand new, it is not activated by default, you can activate it in the usual
+commands described hereafter with `-n htdemucs_ft`.
+The single, non fine-tuned model is provided as `-n htdemucs`, and the retrained baseline
+as `-n hdemucs_mmi`. The Sparse Hybrid Transformer model described in our paper is not provided as it
+requires custom CUDA code that is not ready for release yet.
+We are also releasing an experimental 6 sources model, that adds a `guitar` and `piano` source.
+Quick testing seems to show okay quality for `guitar`, but a lot of bleeding and artifacts for the `piano` source.
+
+
+
+
+
+
+
+## Important news if you are already using Demucs
+
+See the [release notes](./docs/release.md) for more details.
+
+- 22/02/2023: added support for the [SDX 2023 Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023),
+ see the dedicated [doc page](./docs/sdx23.md)
+- 07/12/2022: Demucs v4 now on PyPI. **htdemucs** model now used by default. Also releasing
+ a 6 sources models (adding `guitar` and `piano`, although the latter doesn't work so well at the moment).
+- 16/11/2022: Added the new **Hybrid Transformer Demucs v4** models.
+ Adding support for the [torchaudio implementation of HDemucs](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html).
+- 30/08/2022: added reproducibility and ablation grids, along with an updated version of the paper.
+- 17/08/2022: Releasing v3.0.5: Set split segment length to reduce memory. Compatible with pyTorch 1.12.
+- 24/02/2022: Releasing v3.0.4: split into two stems (i.e. karaoke mode).
+ Export as float32 or int24.
+- 17/12/2021: Releasing v3.0.3: bug fixes (thanks @keunwoochoi), memory drastically
+ reduced on GPU (thanks @famzah) and new multi-core evaluation on CPU (`-j` flag).
+- 12/11/2021: Releasing **Demucs v3** with hybrid domain separation. Strong improvements
+ on all sources. This is the model that won Sony MDX challenge.
+- 11/05/2021: Adding support for MusDB-HQ and arbitrary wav set, for the MDX challenge. For more information
+on joining the challenge with Demucs see [the Demucs MDX instructions](docs/mdx.md)
+
+
+## Comparison with other models
+
+We provide hereafter a summary of the different metrics presented in the paper.
+You can also compare Hybrid Demucs (v3), [KUIELAB-MDX-Net][kuielab], [Spleeter][spleeter], Open-Unmix, Demucs (v1), and Conv-Tasnet on one of my favorite
+songs on my [soundcloud playlist][soundcloud].
+
+### Comparison of accuracy
+
+`Overall SDR` is the mean of the SDR for each of the 4 sources, `MOS Quality` is a rating from 1 to 5
+of the naturalness and absence of artifacts given by human listeners (5 = no artifacts), `MOS Contamination`
+is a rating from 1 to 5 with 5 being zero contamination by other sources. We refer the reader to our [paper][hybrid_paper],
+for more details.
+
+| Model | Domain | Extra data? | Overall SDR | MOS Quality | MOS Contamination |
+|------------------------------|-------------|-------------------|-------------|-------------|-------------------|
+| [Wave-U-Net][waveunet] | waveform | no | 3.2 | - | - |
+| [Open-Unmix][openunmix] | spectrogram | no | 5.3 | - | - |
+| [D3Net][d3net] | spectrogram | no | 6.0 | - | - |
+| [Conv-Tasnet][demucs_v2] | waveform | no | 5.7 | - | |
+| [Demucs (v2)][demucs_v2] | waveform | no | 6.3 | 2.37 | 2.36 |
+| [ResUNetDecouple+][decouple] | spectrogram | no | 6.7 | - | - |
+| [KUIELAB-MDX-Net][kuielab] | hybrid | no | 7.5 | **2.86** | 2.55 |
+| [Band-Split RNN][bandsplit]  | spectrogram | no                | **8.2**     | -           | -                 |
+| **Hybrid Demucs (v3)** | hybrid | no | 7.7 | **2.83** | **3.04** |
+| [MMDenseLSTM][mmdenselstm] | spectrogram | 804 songs | 6.0 | - | - |
+| [D3Net][d3net] | spectrogram | 1.5k songs | 6.7 | - | - |
+| [Spleeter][spleeter] | spectrogram | 25k songs | 5.9 | - | - |
+| [Band-Split RNN][bandsplit]  | spectrogram | 1.7k (mixes only) | **9.0**     | -           | -                 |
+| **HT Demucs f.t. (v4)** | hybrid | 800 songs | **9.0** | - | - |
+
+
+
+## Requirements
+
+You will need at least Python 3.8. See `requirements_minimal.txt` for requirements for separation only,
+and `environment-[cpu|cuda].yml` (or `requirements.txt`) if you want to train a new model.
+
+### For Windows users
+
+Every time you see `python3`, replace it with `python.exe`. You should always run commands from the
+Anaconda console.
+
+### For musicians
+
+If you just want to use Demucs to separate tracks, you can install it with
+
+```bash
+python3 -m pip install -U demucs
+```
+
+For bleeding edge versions, you can install directly from this repo using
+```bash
+python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs
+```
+
+Advanced OS support is provided on the following pages, **you must read the page for your OS before posting an issue**:
+- **If you are using Windows:** [Windows support](docs/windows.md).
+- **If you are using macOS:** [macOS support](docs/mac.md).
+- **If you are using Linux:** [Linux support](docs/linux.md).
+
+### For machine learning scientists
+
+If you have anaconda installed, you can run from the root of this repository:
+
+```bash
+conda env update -f environment-cpu.yml # if you don't have GPUs
+conda env update -f environment-cuda.yml # if you have GPUs
+conda activate demucs
+pip install -e .
+```
+
+This will create a `demucs` environment with all the dependencies installed.
+
+You will also need to install [soundstretch/soundtouch](https://www.surina.net/soundtouch/soundstretch.html): on macOS you can do `brew install sound-touch`,
+and on Ubuntu `sudo apt-get install soundstretch`. This is used for the
+pitch/tempo augmentation.
+
+
+### Running in Docker
+
+Thanks to @xserrat, there is now a Docker image definition ready for using Demucs. This can ensure all libraries are correctly installed without interfering with the host OS. See his repo [Docker Facebook Demucs](https://github.com/xserrat/docker-facebook-demucs) for more information.
+
+
+### Running from Colab
+
+I made a Colab to easily separate tracks with Demucs. Note that
+transfer speeds with Colab are a bit slow for large media files,
+but it will allow you to use Demucs without installing anything.
+
+[Demucs on Google Colab](https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing)
+
+### Web Demo
+
+Integrated to [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See demo: [Demucs on Hugging Face Spaces](https://huggingface.co/spaces/akhaliq/demucs)
+
+### Graphical Interface
+
+@CarlGao4 has released a GUI for Demucs: [CarlGao4/Demucs-Gui](https://github.com/CarlGao4/Demucs-Gui). Downloads for Windows and macOS is available [here](https://github.com/CarlGao4/Demucs-Gui/releases). Use [FossHub mirror](https://fosshub.com/Demucs-GUI.html) to speed up your download.
+
+@Anjok07 is providing a self contained GUI in [UVR (Ultimate Vocal Remover)](https://github.com/facebookresearch/demucs/issues/334) that supports Demucs.
+
+### Other providers
+
+Audiostrip is providing free online separation with Demucs on their website [https://audiostrip.co.uk/](https://audiostrip.co.uk/).
+
+[MVSep](https://mvsep.com/) also provides free online separation, select `Demucs3 model B` for the best quality.
+
+[Neutone](https://neutone.space/) provides a realtime Demucs model in their free VST/AU plugin that can be used in your favorite DAW.
+
+
+## Separating tracks
+
+In order to try Demucs, you can just run from any folder (as long as you properly installed it)
+
+```bash
+demucs PATH_TO_AUDIO_FILE_1 [PATH_TO_AUDIO_FILE_2 ...] # for Demucs
+# If you used `pip install --user` you might need to replace demucs with python3 -m demucs
+python3 -m demucs --mp3 --mp3-bitrate BITRATE PATH_TO_AUDIO_FILE_1 # output files saved as MP3
+ # use --mp3-preset to change encoder preset, 2 for best quality, 7 for fastest
+# If your filename contains spaces don't forget to quote it !!!
+demucs "my music/my favorite track.mp3"
+# You can select different models with `-n` mdx_q is the quantized model, smaller but maybe a bit less accurate.
+demucs -n mdx_q myfile.mp3
+# If you only want to separate vocals out of an audio, use `--two-stems=vocals` (You can also set to drums or bass)
+demucs --two-stems=vocals myfile.mp3
+```
+
+
+If you have a GPU, but you run out of memory, please use `--segment SEGMENT` to reduce length of each split. `SEGMENT` should be changed to an integer describing the length of each segment in seconds.
+A segment length of at least 10 is recommended (the bigger the number is, the more memory is required, but quality may increase). Note that the Hybrid Transformer models only support a maximum segment length of 7.8 seconds.
+Creating an environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` is also helpful. If this still does not help, please add `-d cpu` to the command line. See the section hereafter for more details on the memory requirements for GPU acceleration.
+
+Separated tracks are stored in the `separated/MODEL_NAME/TRACK_NAME` folder. There you will find four stereo wav files sampled at 44.1 kHz: `drums.wav`, `bass.wav`,
+`other.wav`, `vocals.wav` (or `.mp3` if you used the `--mp3` option).
+
+All audio formats supported by `torchaudio` can be processed (i.e. wav, mp3, flac, ogg/vorbis on Linux/macOS, etc.). On Windows, `torchaudio` has limited support, so we rely on `ffmpeg`, which should support pretty much anything.
+Audio is resampled on the fly if necessary.
+The output will be a wav file encoded as int16.
+You can save as float32 wav files with `--float32`, or 24 bits integer wav with `--int24`.
+You can pass `--mp3` to save as mp3 instead, and set the bitrate (in kbps) with `--mp3-bitrate` (default is 320).
+
+It can happen that the output would need clipping, in particular due to some separation artifacts.
+Demucs will automatically rescale each output stem so as to avoid clipping. This can however break
+the relative volume between stems. If instead you prefer hard clipping, pass `--clip-mode clamp`.
+You can also try to reduce the volume of the input mixture before feeding it to Demucs.
+
+
+Other pre-trained models can be selected with the `-n` flag.
+The list of pre-trained models is:
+- `htdemucs`: first version of Hybrid Transformer Demucs. Trained on MusDB + 800 songs. Default model.
+- `htdemucs_ft`: fine-tuned version of `htdemucs`, separation will take 4 times more time
+ but might be a bit better. Same training set as `htdemucs`.
+- `htdemucs_6s`: 6 sources version of `htdemucs`, with `piano` and `guitar` being added as sources.
+ Note that the `piano` source is not working great at the moment.
+- `hdemucs_mmi`: Hybrid Demucs v3, retrained on MusDB + 800 songs.
+- `mdx`: trained only on MusDB HQ, winning model on track A at the [MDX][mdx] challenge.
+- `mdx_extra`: trained with extra training data (**including MusDB test set**), ranked 2nd on the track B
+ of the [MDX][mdx] challenge.
+- `mdx_q`, `mdx_extra_q`: quantized version of the previous models. Smaller download and storage
+ but quality can be slightly worse.
+- `SIG`: where `SIG` is a single model from the [model zoo](docs/training.md#model-zoo).
+
+The `--two-stems=vocals` option allows separating vocals from the rest of the accompaniment (i.e., karaoke mode).
+`vocals` can be changed to any source in the selected model.
+This will mix the files after separating the mix fully, so this won't be faster or use less memory.
+
+The `--shifts=SHIFTS` performs multiple predictions with random shifts (a.k.a the *shift trick*) of the input and average them. This makes prediction `SHIFTS` times
+slower. Don't use it unless you have a GPU.
+
+The `--overlap` option controls the amount of overlap between prediction windows. Default is 0.25 (i.e. 25%) which is probably fine.
+It can probably be reduced to 0.1 to improve a bit speed.
+
+
+The `-j` flag allows specifying a number of parallel jobs (e.g. `demucs -j 2 myfile.mp3`).
+This will multiply by the same amount the RAM used so be careful!
+
+### Memory requirements for GPU acceleration
+
+If you want to use GPU acceleration, you will need at least 3GB of RAM on your GPU for `demucs`. However, about 7GB of RAM will be required if you use the default arguments. Add `--segment SEGMENT` to change size of each split. If you only have 3GB memory, set SEGMENT to 8 (though quality may be worse if this argument is too small). Creating an environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` can help users with even smaller RAM such as 2GB (I separated a track that is 4 minutes but only 1.5GB is used), but this would make the separation slower.
+
+If you do not have enough memory on your GPU, simply add `-d cpu` to the command line to use the CPU. With Demucs, processing time should be roughly equal to 1.5 times the duration of the track.
+
+## Calling from another Python program
+
+The main function provides an `opt` parameter as a simple API. You can just pass the parsed command line as this parameter:
+```python
+# Assume that your command is `demucs --mp3 --two-stems vocals -n mdx_extra "track with space.mp3"`
+# The following codes are same as the command above:
+import demucs.separate
+demucs.separate.main(["--mp3", "--two-stems", "vocals", "-n", "mdx_extra", "track with space.mp3"])
+
+# Or like this
+import demucs.separate
+import shlex
+demucs.separate.main(shlex.split('--mp3 --two-stems vocals -n mdx_extra "track with space.mp3"'))
+```
+
+To use more complicated APIs, see [API docs](docs/api.md)
+
+## Training Demucs
+
+If you want to train (Hybrid) Demucs, please follow the [training doc](docs/training.md).
+
+## MDX Challenge reproduction
+
+In order to reproduce the results from the Track A and Track B submissions, checkout the [MDX Hybrid Demucs submission repo][mdx_submission].
+
+
+
+## How to cite
+
+```
+@inproceedings{rouard2022hybrid,
+ title={Hybrid Transformers for Music Source Separation},
+ author={Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
+ booktitle={ICASSP 23},
+ year={2023}
+}
+
+@inproceedings{defossez2021hybrid,
+ title={Hybrid Spectrogram and Waveform Source Separation},
+ author={D{\'e}fossez, Alexandre},
+ booktitle={Proceedings of the ISMIR 2021 Workshop on Music Source Separation},
+ year={2021}
+}
+```
+
+## License
+
+Demucs is released under the MIT license as found in the [LICENSE](LICENSE) file.
+
+[hybrid_paper]: https://arxiv.org/abs/2111.03600
+[waveunet]: https://github.com/f90/Wave-U-Net
+[musdb]: https://sigsep.github.io/datasets/musdb.html
+[openunmix]: https://github.com/sigsep/open-unmix-pytorch
+[mmdenselstm]: https://arxiv.org/abs/1805.02410
+[demucs_v2]: https://github.com/facebookresearch/demucs/tree/v2
+[demucs_v3]: https://github.com/facebookresearch/demucs/tree/v3
+[spleeter]: https://github.com/deezer/spleeter
+[soundcloud]: https://soundcloud.com/honualx/sets/source-separation-in-the-waveform-domain
+[d3net]: https://arxiv.org/abs/2010.01733
+[mdx]: https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021
+[kuielab]: https://github.com/kuielab/mdx-net-submission
+[decouple]: https://arxiv.org/abs/2109.05418
+[mdx_submission]: https://github.com/adefossez/mdx21_demucs
+[bandsplit]: https://arxiv.org/abs/2209.15174
+[htdemucs]: https://arxiv.org/abs/2211.08553
diff --git a/demucs/Readme.md b/demucs/Readme.md
deleted file mode 100644
index 402d2b4a..00000000
--- a/demucs/Readme.md
+++ /dev/null
@@ -1 +0,0 @@
-Inlined demucs https://github.com/adefossez/demucs
\ No newline at end of file
diff --git a/demucs/conf/config.yaml b/demucs/conf/config.yaml
new file mode 100644
index 00000000..d2597cb5
--- /dev/null
+++ b/demucs/conf/config.yaml
@@ -0,0 +1,304 @@
+defaults:
+ - _self_
+ - dset: musdb44
+ - svd: default
+ - variant: default
+ - override hydra/hydra_logging: colorlog
+ - override hydra/job_logging: colorlog
+
+dummy:
+dset:
+ musdb: /checkpoint/defossez/datasets/musdbhq
+ musdb_samplerate: 44100
+ use_musdb: true # set to false to not use musdb as training data.
+ wav: # path to custom wav dataset
+ wav2: # second custom wav dataset
+ segment: 11
+ shift: 1
+ train_valid: false
+ full_cv: true
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ metadata: ./metadata
+ sources: ['drums', 'bass', 'other', 'vocals']
+ valid_samples: # valid dataset size
+ backend: null # if provided select torchaudio backend.
+
+test:
+ save: False
+ best: True
+ workers: 2
+ every: 20
+ split: true
+ shifts: 1
+ overlap: 0.25
+ sdr: true
+ metric: 'loss' # metric used for best model selection on the valid set, can also be nsdr
+ nonhq: # path to non hq MusDB for evaluation
+
+epochs: 360
+batch_size: 64
+max_batches: # limit the number of batches per epoch, useful for debugging
+ # or if your dataset is gigantic.
+optim:
+ lr: 3e-4
+ momentum: 0.9
+ beta2: 0.999
+ loss: l1 # l1 or mse
+ optim: adam
+ weight_decay: 0
+ clip_grad: 0
+
+seed: 42
+debug: false
+valid_apply: true
+flag:
+save_every:
+weights: [1., 1., 1., 1.] # weights over each source for the training/valid loss.
+
+augment:
+ shift_same: false
+ repitch:
+ proba: 0.2
+ max_tempo: 12
+ remix:
+ proba: 1
+ group_size: 4
+ scale:
+ proba: 1
+ min: 0.25
+ max: 1.25
+ flip: true
+
+continue_from: # continue from other XP, give the XP Dora signature.
+continue_pretrained: # signature of a pretrained XP, this cannot be a bag of models.
+pretrained_repo: # repo for pretrained model (default is official AWS)
+continue_best: true
+continue_opt: false
+
+misc:
+ num_workers: 10
+ num_prints: 4
+ show: false
+ verbose: false
+
+# List of decay for EMA at batch or epoch level, e.g. 0.999.
+# Batch level EMA are kept on GPU for speed.
+ema:
+ epoch: []
+ batch: []
+
+use_train_segment: true # to remove
+model_segment: # override the segment parameter for the model, usually 4 times the training segment.
+model: demucs # see demucs/train.py for the possibilities, and config for each model hereafter.
+demucs: # see demucs/demucs.py for a detailed description
+ # Channels
+ channels: 64
+ growth: 2
+ # Main structure
+ depth: 6
+ rewrite: true
+ lstm_layers: 0
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ context: 1
+ # Activations
+ gelu: true
+ glu: true
+ # Normalization
+ norm_groups: 4
+ norm_starts: 4
+ # DConv residual branch
+ dconv_depth: 2
+ dconv_mode: 1 # 1 = branch in encoder, 2 = in decoder, 3 = in both.
+ dconv_comp: 4
+ dconv_attn: 4
+ dconv_lstm: 4
+ dconv_init: 1e-4
+ # Pre/post treatment
+ resample: true
+ normalize: false
+ # Weight init
+ rescale: 0.1
+
+hdemucs: # see demucs/hdemucs.py for a detailed description
+ # Channels
+ channels: 48
+ channels_time:
+ growth: 2
+ # STFT
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: false
+ cac: true
+ # Main structure
+ depth: 6
+ rewrite: true
+ hybrid: true
+ hybrid_old: false
+ # Frequency Branch
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: true
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ # normalization
+ norm_starts: 4
+ norm_groups: 4
+ # DConv residual branch
+ dconv_mode: 1
+ dconv_depth: 2
+ dconv_comp: 4
+ dconv_attn: 4
+ dconv_lstm: 4
+ dconv_init: 1e-3
+ # Weight init
+ rescale: 0.1
+
+# Torchaudio implementation of HDemucs
+torch_hdemucs:
+  # Channels
+ channels: 48
+ growth: 2
+ # STFT
+ nfft: 4096
+ # Main structure
+ depth: 6
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: true
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ # normalization
+ norm_starts: 4
+ norm_groups: 4
+ # DConv residual branch
+ dconv_depth: 2
+ dconv_comp: 4
+ dconv_attn: 4
+ dconv_lstm: 4
+ dconv_init: 1e-3
+
+htdemucs: # see demucs/htdemucs.py for a detailed description
+ # Channels
+ channels: 48
+ channels_time:
+ growth: 2
+ # STFT
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: false
+ cac: true
+ # Main structure
+ depth: 4
+ rewrite: true
+ # Frequency Branch
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: true
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ # normalization
+ norm_starts: 4
+ norm_groups: 4
+ # DConv residual branch
+ dconv_mode: 1
+ dconv_depth: 2
+ dconv_comp: 8
+ dconv_init: 1e-3
+ # Before the Transformer
+ bottom_channels: 0
+ # CrossTransformer
+ # ------ Common to all
+ # Regular parameters
+ t_layers: 5
+ t_hidden_scale: 4.0
+ t_heads: 8
+ t_dropout: 0.0
+ t_layer_scale: True
+ t_gelu: True
+ # ------------- Positional Embedding
+ t_emb: sin
+ t_max_positions: 10000 # for the scaled embedding
+ t_max_period: 10000.0
+ t_weight_pos_embed: 1.0
+ t_cape_mean_normalize: True
+ t_cape_augment: True
+ t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
+ t_sin_random_shift: 0
+ # ------------- norm before a transformer encoder
+ t_norm_in: True
+ t_norm_in_group: False
+ # ------------- norm inside the encoder
+ t_group_norm: False
+ t_norm_first: True
+ t_norm_out: True
+ # ------------- optim
+ t_weight_decay: 0.0
+ t_lr:
+ # ------------- sparsity
+ t_sparse_self_attn: False
+ t_sparse_cross_attn: False
+ t_mask_type: diag
+ t_mask_random_seed: 42
+ t_sparse_attn_window: 400
+ t_global_window: 100
+ t_sparsity: 0.95
+ t_auto_sparsity: False
+ # Cross Encoder First (False)
+ t_cross_first: False
+ # Weight init
+ rescale: 0.1
+
+svd: # see svd.py for documentation
+ penalty: 0
+ min_size: 0.1
+ dim: 1
+ niters: 2
+ powm: false
+ proba: 1
+ conv_only: false
+ convtr: false
+ bs: 1
+
+quant: # quantization hyper params
+ diffq: # diffq penalty, typically 1e-4 or 3e-4
+ qat: # use QAT with a fixed number of bits (not as good as diffq)
+ min_size: 0.2
+ group_size: 8
+
+dora:
+ dir: outputs
+ exclude: ["misc.*", "slurm.*", 'test.reval', 'flag', 'dset.backend']
+
+slurm:
+ time: 4320
+ constraint: volta32gb
+ setup: ['module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 cuda/11.6']
+
+# Hydra config
+hydra:
+ job_logging:
+ formatters:
+ colorlog:
+ datefmt: "%m-%d %H:%M:%S"
diff --git a/demucs/conf/dset/aetl.yaml b/demucs/conf/dset/aetl.yaml
new file mode 100644
index 00000000..7c983160
--- /dev/null
+++ b/demucs/conf/dset/aetl.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+# automix dataset with Musdb, extra training data and the test set of Musdb.
+# This used even more remixes than auto_extra_test.
+dset:
+ wav: /checkpoint/defossez/datasets/aetl
+ samplerate: 44100
+ channels: 2
+epochs: 320
+max_batches: 500
+
+augment:
+ shift_same: true
+ scale:
+ proba: 0.
+ remix:
+ proba: 0
+ repitch:
+ proba: 0
diff --git a/demucs/conf/dset/auto_extra_test.yaml b/demucs/conf/dset/auto_extra_test.yaml
new file mode 100644
index 00000000..056183a5
--- /dev/null
+++ b/demucs/conf/dset/auto_extra_test.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+
+# automix dataset with Musdb, extra training data and the test set of Musdb.
+dset:
+ wav: /checkpoint/defossez/datasets/automix_extra_test2
+ samplerate: 44100
+ channels: 2
+epochs: 320
+max_batches: 500
+
+augment:
+ shift_same: true
+ scale:
+ proba: 0.
+ remix:
+ proba: 0
+ repitch:
+ proba: 0
diff --git a/demucs/conf/dset/auto_mus.yaml b/demucs/conf/dset/auto_mus.yaml
new file mode 100644
index 00000000..9a2d9df5
--- /dev/null
+++ b/demucs/conf/dset/auto_mus.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+
+# Automix dataset based on musdb train set.
+dset:
+ wav: /checkpoint/defossez/datasets/automix_musdb
+ samplerate: 44100
+ channels: 2
+epochs: 360
+max_batches: 300
+test:
+ every: 4
+
+augment:
+ shift_same: true
+ scale:
+ proba: 0.5
+ remix:
+ proba: 0
+ repitch:
+ proba: 0
diff --git a/demucs/conf/dset/extra44.yaml b/demucs/conf/dset/extra44.yaml
new file mode 100644
index 00000000..f0adc467
--- /dev/null
+++ b/demucs/conf/dset/extra44.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+
+# Musdb + extra tracks
+dset:
+ wav: /checkpoint/defossez/datasets/allstems_44/
+ samplerate: 44100
+ channels: 2
+epochs: 320
diff --git a/demucs/conf/dset/extra_mmi_goodclean.yaml b/demucs/conf/dset/extra_mmi_goodclean.yaml
new file mode 100644
index 00000000..fe47bcf2
--- /dev/null
+++ b/demucs/conf/dset/extra_mmi_goodclean.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# Musdb + extra tracks
+dset:
+ wav: /checkpoint/defossez/datasets/allstems_44/
+ wav2: /checkpoint/defossez/datasets/mmi44_goodclean
+ samplerate: 44100
+ channels: 2
+ wav2_weight: null
+ wav2_valid: false
+ valid_samples: 100
+epochs: 1200
diff --git a/demucs/conf/dset/extra_test.yaml b/demucs/conf/dset/extra_test.yaml
new file mode 100644
index 00000000..1e7d05ad
--- /dev/null
+++ b/demucs/conf/dset/extra_test.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# Musdb + extra tracks + test set from musdb.
+dset:
+ wav: /checkpoint/defossez/datasets/allstems_test_44/
+ samplerate: 44100
+ channels: 2
+epochs: 320
+max_batches: 700
+test:
+ sdr: false
+ every: 500
diff --git a/demucs/conf/dset/musdb44.yaml b/demucs/conf/dset/musdb44.yaml
new file mode 100644
index 00000000..c5623468
--- /dev/null
+++ b/demucs/conf/dset/musdb44.yaml
@@ -0,0 +1,5 @@
+# @package _global_
+
+dset:
+ samplerate: 44100
+ channels: 2
\ No newline at end of file
diff --git a/demucs/conf/dset/sdx23_bleeding.yaml b/demucs/conf/dset/sdx23_bleeding.yaml
new file mode 100644
index 00000000..5f7fd1e4
--- /dev/null
+++ b/demucs/conf/dset/sdx23_bleeding.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+# SDX23 bleeding dataset (MusDB is not used: use_musdb is false below)
+dset:
+ wav: /shared/home/defossez/data/datasets/moisesdb23_bleeding_v1.0/
+ use_musdb: false
+ samplerate: 44100
+ channels: 2
+  backend: soundfile # must use soundfile as some mixtures would clip with sox.
+epochs: 320
diff --git a/demucs/conf/dset/sdx23_labelnoise.yaml b/demucs/conf/dset/sdx23_labelnoise.yaml
new file mode 100644
index 00000000..367769e6
--- /dev/null
+++ b/demucs/conf/dset/sdx23_labelnoise.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+# SDX23 label-noise dataset (MusDB is not used: use_musdb is false below)
+dset:
+ wav: /shared/home/defossez/data/datasets/moisesdb23_labelnoise_v1.0
+ use_musdb: false
+ samplerate: 44100
+ channels: 2
+  backend: soundfile # must use soundfile as some mixtures would clip with sox.
+epochs: 320
diff --git a/demucs/conf/svd/base.yaml b/demucs/conf/svd/base.yaml
new file mode 100644
index 00000000..e4de8685
--- /dev/null
+++ b/demucs/conf/svd/base.yaml
@@ -0,0 +1,14 @@
+# @package _global_
+
+svd:
+ penalty: 0
+ min_size: 1
+ dim: 50
+ niters: 4
+ powm: false
+ proba: 1
+ conv_only: false
+ convtr: false # ideally this should be true, but some models were trained with this to false.
+
+optim:
+ beta2: 0.9998
\ No newline at end of file
diff --git a/demucs/conf/svd/base2.yaml b/demucs/conf/svd/base2.yaml
new file mode 100644
index 00000000..b88a7519
--- /dev/null
+++ b/demucs/conf/svd/base2.yaml
@@ -0,0 +1,14 @@
+# @package _global_
+
+svd:
+ penalty: 0
+ min_size: 1
+ dim: 100
+ niters: 4
+ powm: false
+ proba: 1
+ conv_only: false
+ convtr: true
+
+optim:
+ beta2: 0.9998
\ No newline at end of file
diff --git a/demucs/conf/svd/default.yaml b/demucs/conf/svd/default.yaml
new file mode 100644
index 00000000..03bfe3db
--- /dev/null
+++ b/demucs/conf/svd/default.yaml
@@ -0,0 +1 @@
+# @package _global_
diff --git a/demucs/conf/variant/default.yaml b/demucs/conf/variant/default.yaml
new file mode 100644
index 00000000..03bfe3db
--- /dev/null
+++ b/demucs/conf/variant/default.yaml
@@ -0,0 +1 @@
+# @package _global_
diff --git a/demucs/conf/variant/example.yaml b/demucs/conf/variant/example.yaml
new file mode 100644
index 00000000..9b38aeca
--- /dev/null
+++ b/demucs/conf/variant/example.yaml
@@ -0,0 +1,5 @@
+# @package _global_
+
+model: hdemucs
+hdemucs:
+ channels: 32
\ No newline at end of file
diff --git a/demucs/conf/variant/finetune.yaml b/demucs/conf/variant/finetune.yaml
new file mode 100644
index 00000000..c3ea21ed
--- /dev/null
+++ b/demucs/conf/variant/finetune.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+epochs: 4
+batch_size: 16
+optim:
+ lr: 0.0006
+test:
+ every: 1
+ sdr: false
+dset:
+ segment: 28
+ shift: 2
+
+augment:
+ scale:
+ proba: 0
+ shift_same: true
+ remix:
+ proba: 0
diff --git a/demucs/demucs.png b/demucs/demucs.png
new file mode 100644
index 0000000000000000000000000000000000000000..d043f64442f24d1825dfabb3eed57ff0f843f64a
GIT binary patch
literal 339294
zcmeFYg;O0*(=dv=90>02PH=bk;LgF_J-BmlC%6;b-3byL0tB}p!4B^7@q1o-?)~0>
z;8xwOn%dsho!;K*o!RMVRb?4eBmyJ|2nbX;SxI#W20Rcgm1T-W%t~w=-$v^y-m^Amq7Fqo5)zIG7!BOpAq;9-=`%$Ucdppg8ChpUjg#tOssbIiwyvXfX_i4|+vA$UgWF
zXc&v^88NXYC5$RBemG(k8H#9H`oM7{9S$Z(o}W&xh7NLW`VlrwD(f#kHBx*(Kwk4=
zQekO9jg5ni!^{LCoJaEZkZ{}C9BWTBs`T8%gv|yHX;U)Fq7P{DyD2I-7)&g#s7X53
zH`VV7gBW<(mmeeVeew`BAF9*AaTVzO**^5FOS{=T*L{{G$t3J!k6
zfgTEWg5WqcIV_)5$op`qmW7U-rIHc^!v~B20U2on0rLSt{^RVA_Tl;D2nhI(Gv-H=
zEP(niDs*@O^nb&@{$Ui;kdTx6IBNi1Ei4?}texCR<;j#kR4v(j(Q(&NQsf6ZIk1?T
zJDFLqcsn@%BLX4l&Hn*9Sh$;#c{|uUy77ApQT&U8{{#L9%}PP`FBW$@AqpKORWb=D
zR|_(37B&_(3SlHNGBQC|b4z}8N$LL-|F{yOuy%KM=4WN~^73Nw;$(4hwPI!GcJpy`H}z(Abff$aCI6*I(!vesYUAu~q^=ZI>pwmJH&6VB
zoBxIV@Ut+IAnSiWGhrmV!H8=J2vG<*NwF{9kf(<5-dLKrJuCL&;GX-<*n%5|s2DmN
ztS?xn6~9HLu)5xU`-x#mi2<+yQojFi_(f86WxeMRM@&vnFWFO&g9D5kyEHbt`K}Lo
z)|^0am{LfP;X&e}Rs1KcJ
zRx-AJN+C|>!s0R6e5IMa8JSf-m8o1A5<;Yz;J(GG?ftdO)3(vd$$D9@k6#y!^lB?T
zpQU8{poA*v^_gp>Q7Q%XB!!xs{AwacxMut>1YWlOX2c2cu*br;R(rX0POJI%&c9#t
zD?2u)bI0@8ld`!f_`C*59r-+&nys2#B&@;NeE}CFe4AdUvgv!H5?y6|twk*xXE<5H
z^tuoDpoh)rynRyWR(8AZ@2~cO13BiV0s}{but{0gjHv;zc8u$W
zHLP$w{>@zdOWS_A0;>}**~lY8*(^&*LVE`O8R6H?*ztSG<0koCDRe1dHI{}
zri!xeX=Pt8vA_Qot=?NC;VTSSX$djvckS=}n(hB0mRzm#)9NVIP5>5MT#A^V-xufU
zgV+0g(X7OBN{8M64=9{#v(ag;+lDhZ<*Bsl6_v*Cp?)8bbO~Ok7i-*3`@1qhpX($q
zq+BTKiZ&Uz(TI38aPQ(egc@9z73Os`(u&A&v+PZ8B#Q%M*=?iSb#c>hCrImFQWL
zd5l5rDO`Mh(&j__O7dy0viE*wm+RuB*MCI&UMH!sGDrgiZ~Rse$s5;{5IF)|ap-l%
za8&|latCjIVmtcm;xmoJ{4-{=!iWb#_b6+KU_oW!2Ggf(xDrf08R--bV*HvOvfYch
zHtfC&>umW`B-Q|QnfPJ1XUlm+8+1THI_c?!2R01Y{SzK#A5a@~JsNTKaa
z8wKy2lbidN==og3ykPZ+HUnpij_
zc+m}{maVx)h>Tzj6{>MVA+tZVU3-77RZ>crVJd!q5mkBVeH)}kCE#OoWMjch*)Oi!
z9mGVDNpKw>1$R-A_D+5+atznA>d4A;Jk?JYHY+@(;Y!
zg_}`ls^T1J$3xop2^X{wf
zb5^(f1K$Y1>8ncQC~+oQnHTd_WGH#XbReqB!lTfqT_hJc?|p2Cd}gGaFh+ANci>02tsU
zw>6q#KAEq}MjPg-nSq=snnVHV7_2pbcUo>XOx6p=o!-fKr|#$db#l$imIpU*5pC7Y
zT}Aql$RPmrflb(-|4Hf0Q0{v3QAQbttDn@#jzGGb4@&K7zb(lO6(5WYh
z>Xz`w7#{-6K-9VKV-VzM(8RlQ_pix1N`pqn(lKwr(Zf9X)KJq-;OpHtWT+qwUlBZ3
zE?eFb9rj2RQj!r=qhqOrOfIrNdhFc!Sr4%*
z!s|vieRuEKiUg015=j7+l#F3hnZ$Xhru#(~OG1?y?>#{c(B`c4nJavKU!gh3aovB*
zQ=D!x%h>$Z)~`AW=*o1!L??NAbYY4hbaAYMU@8c>)N6Kc31NHK1~y!A9>h5xM2I@9
zofVteZU<4>z*(%^4_PDL&2W46{GeA6IZ1{0qYgX8RNisK-CQ_>T+-;oX4N|qvg|1H
zO(iy+7FyfR>qbRFsb&AyaODK^(zmfWT50^RxGK0kqY8`j&0T1s?3cJxKmxS4r5fSMR@ekLy_l}B7-!DdVgBy4^3!0k1^Kq
z?X$tqay*&y&WF6+J0}17{)sZJinj!;;nIkw)gdZ=6un2G&Eqmwt3Nih6(}zWCu^pgz-=Ip*u9m{Z$c5nibhQ`8bfH69?tj_9xZ6&}5oH{3oGjpX
z6q=_M{(X_iWVUB(Nc1a}v_KmY(?SDD#xb`8P>7|Vu=f)vFqcy!rPaZGhlX}NP$4AP
z5@b*!sleyLpruIqsky`xjBz-SFzH?)Nq#Xv3*5op9?oGU~Tc`z=v-*_C+`x_I9+6GxKV^HL*=qlKtN#(M`c42$ZZIt0yA=mU%}
zt8IHbLU2ClWWzkkrTCb`^btY3(
zDT6h9@Oj}erk2A%P(URJd$T