Compare commits
No commits in common. "main" and "v1.2.0" have entirely different histories.
14
.coveragerc
|
|
@ -1,19 +1,7 @@
|
|||
[run]
|
||||
omit =
|
||||
buzz/whisper_cpp/*
|
||||
buzz/transcriber/local_whisper_cpp_server_transcriber.py
|
||||
buzz/whisper_cpp.py
|
||||
*_test.py
|
||||
demucs/*
|
||||
whisper_diarization/*
|
||||
deepmultilingualpunctuation/*
|
||||
ctc_forced_aligner/*
|
||||
|
||||
[report]
|
||||
exclude_also =
|
||||
if sys.platform == "win32":
|
||||
if platform.system\(\) == "Windows":
|
||||
if platform.system\(\) == "Linux":
|
||||
if platform.system\(\) == "Darwin":
|
||||
|
||||
[html]
|
||||
directory = coverage/html
|
||||
|
|
|
|||
307
.github/workflows/ci.yml
vendored
|
|
@ -1,3 +1,4 @@
|
|||
---
|
||||
name: CI
|
||||
on:
|
||||
push:
|
||||
|
|
@ -14,88 +15,72 @@ concurrency:
|
|||
jobs:
|
||||
test:
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-15-intel
|
||||
- os: macos-13
|
||||
- os: macos-latest
|
||||
- os: windows-latest
|
||||
- os: ubuntu-20.04
|
||||
- os: ubuntu-22.04
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
# Should be removed with next update to whisper.cpp
|
||||
- name: Downgrade Xcode
|
||||
uses: maxim-lobanov/setup-xcode@v1
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
xcode-version: '16.0.0'
|
||||
if: matrix.os == 'macos-latest'
|
||||
python-version: "3.11.9"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install Poetry Action
|
||||
uses: snok/install-poetry@v1.3.1
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install Vulkan SDK
|
||||
if: "startsWith(matrix.os, 'ubuntu-') || matrix.os == 'windows-latest'"
|
||||
uses: humbletim/install-vulkan-sdk@v1.2
|
||||
with:
|
||||
version: 1.4.309.0
|
||||
cache: true
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
||||
- name: Load cached venv
|
||||
id: cached-uv-dependencies
|
||||
id: cached-poetry-dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/uv.lock') }}
|
||||
key: venv-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- uses: AnimMouse/setup-ffmpeg@v1
|
||||
- name: Load cached Whisper models
|
||||
id: cached-whisper-models
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/Buzz
|
||||
~/.cache/whisper
|
||||
~/.cache/huggingface
|
||||
~/AppData/Local/Buzz/Buzz/Cache
|
||||
key: whisper-models
|
||||
|
||||
- uses: FedericoCarboni/setup-ffmpeg@v3.1
|
||||
id: setup-ffmpeg
|
||||
with:
|
||||
version: ${{ matrix.os == 'macos-15-intel' && '7.1.1' || matrix.os == 'macos-latest' && '80' || '8.0' }}
|
||||
ffmpeg-version: release
|
||||
architecture: 'x64'
|
||||
github-token: ${{ github.server_url == 'https://github.com' && github.token || '' }}
|
||||
|
||||
- name: Test ffmpeg
|
||||
run: ffmpeg -i ./testdata/audio-long.mp3 ./testdata/audio-long.wav
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v2
|
||||
if: runner.os == 'Windows'
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
|
||||
if [ "$(lsb_release -rs)" == "22.04" ]; then
|
||||
if [ "$(lsb_release -rs)" != "24.04" ]; then
|
||||
sudo apt-get install libegl1-mesa
|
||||
|
||||
# Add ubuntu-toolchain-r PPA for newer libstdc++6 with GLIBCXX_3.4.32
|
||||
sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libstdc++6
|
||||
fi
|
||||
|
||||
sudo apt-get install libyaml-dev libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-shape0 libxcb-cursor0 libportaudio2 gettext libpulse0 libgl1-mesa-dev libvulkan-dev ccache
|
||||
sudo apt-get install libyaml-dev libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-shape0 libxcb-cursor0 libportaudio2 gettext libpulse0 libgl1-mesa-dev
|
||||
if: "startsWith(matrix.os, 'ubuntu-')"
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
uv run make test
|
||||
poetry run make test
|
||||
shell: bash
|
||||
env:
|
||||
PYTHONFAULTHANDLER: "1"
|
||||
|
||||
- name: Upload coverage reports to Codecov with GitHub Action
|
||||
uses: codecov/codecov-action@v4
|
||||
|
|
@ -107,14 +92,11 @@ jobs:
|
|||
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 90
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-15-intel
|
||||
- os: macos-13
|
||||
- os: macos-latest
|
||||
- os: windows-latest
|
||||
steps:
|
||||
|
|
@ -122,130 +104,87 @@ jobs:
|
|||
with:
|
||||
submodules: recursive
|
||||
|
||||
# Should be removed with next update to whisper.cpp
|
||||
- name: Downgrade Xcode
|
||||
uses: maxim-lobanov/setup-xcode@v1
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
xcode-version: '16.0.0'
|
||||
if: matrix.os == 'macos-latest'
|
||||
python-version: "3.11.9"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install Poetry Action
|
||||
uses: snok/install-poetry@v1.3.1
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install Vulkan SDK
|
||||
if: "startsWith(matrix.os, 'ubuntu-') || matrix.os == 'windows-latest'"
|
||||
uses: humbletim/install-vulkan-sdk@v1.2
|
||||
with:
|
||||
version: 1.4.309.0
|
||||
cache: true
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
||||
- name: Load cached venv
|
||||
id: cached-uv-dependencies
|
||||
id: cached-poetry-dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/uv.lock') }}
|
||||
key: venv-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Install Inno Setup on Windows
|
||||
uses: crazy-max/ghaction-chocolatey@v3
|
||||
- uses: FedericoCarboni/setup-ffmpeg@v3.1
|
||||
id: setup-ffmpeg
|
||||
with:
|
||||
args: install innosetup --yes
|
||||
if: runner.os == 'Windows'
|
||||
ffmpeg-version: release
|
||||
architecture: 'x64'
|
||||
github-token: ${{ github.server_url == 'https://github.com' && github.token || '' }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- uses: ruby/setup-ruby@v1
|
||||
with:
|
||||
ruby-version: "3.0" # Not needed with a .ruby-version file
|
||||
bundler-cache: true # runs 'bundle install' and caches installed gems automatically
|
||||
if: "startsWith(matrix.os, 'ubuntu-')"
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
|
||||
if [ "$(lsb_release -rs)" == "22.04" ]; then
|
||||
if [ "$(lsb_release -rs)" != "24.04" ]; then
|
||||
sudo apt-get install libegl1-mesa
|
||||
|
||||
# Add ubuntu-toolchain-r PPA for newer libstdc++6 with GLIBCXX_3.4.32
|
||||
sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libstdc++6
|
||||
fi
|
||||
|
||||
sudo apt-get install libyaml-dev libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-shape0 libxcb-cursor0 libportaudio2 gettext libpulse0 libgl1-mesa-dev libvulkan-dev ccache
|
||||
if: "startsWith(matrix.os, 'ubuntu-')"
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
- uses: AnimMouse/setup-ffmpeg@v1
|
||||
id: setup-ffmpeg
|
||||
with:
|
||||
version: ${{ matrix.os == 'macos-15-intel' && '7.1.1' || matrix.os == 'macos-latest' && '80' || '8.0' }}
|
||||
|
||||
- name: Install MSVC for Windows
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
uv add msvc-runtime
|
||||
uv pip install -U torch==2.8.0+cu129 torchaudio==2.8.0+cu129 --index-url https://download.pytorch.org/whl/cu129
|
||||
uv pip install nvidia-cublas-cu12==12.9.1.4 nvidia-cuda-cupti-cu12==12.9.79 nvidia-cuda-runtime-cu12==12.9.79 --extra-index-url https://pypi.ngc.nvidia.com
|
||||
|
||||
uv cache clean
|
||||
uv run pip cache purge
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v2
|
||||
if: runner.os == 'Windows'
|
||||
|
||||
- uses: ruby/setup-ruby@v1
|
||||
with:
|
||||
ruby-version: "3.0"
|
||||
bundler-cache: true
|
||||
sudo apt-get install libyaml-dev libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-shape0 libxcb-cursor0 libportaudio2 gettext libpulse0 libgl1-mesa-dev
|
||||
if: "startsWith(matrix.os, 'ubuntu-')"
|
||||
|
||||
- name: Install FPM
|
||||
run: gem install fpm
|
||||
if: "startsWith(matrix.os, 'ubuntu-')"
|
||||
|
||||
- name: Clear space on Windows
|
||||
if: runner.os == 'Windows'
|
||||
run: |
|
||||
rm 'C:\Android\android-sdk\' -r -force
|
||||
rm 'C:\Program Files (x86)\Google\' -r -force
|
||||
rm 'C:\tools\kotlinc\' -r -force
|
||||
rm 'C:\tools\php\' -r -force
|
||||
rm 'C:\selenium\' -r -force
|
||||
shell: pwsh
|
||||
|
||||
- name: Bundle
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "macOS" ]; then
|
||||
|
||||
brew install create-dmg
|
||||
|
||||
|
||||
# kill XProtect to prevent https://github.com/actions/runner-images/issues/7522
|
||||
sudo pkill -9 XProtect >/dev/null || true;
|
||||
while pgrep XProtect; do sleep 3; done;
|
||||
|
||||
# create variables
|
||||
CERTIFICATE_PATH=$RUNNER_TEMP/build_certificate.p12
|
||||
KEYCHAIN_PATH=$RUNNER_TEMP/app-signing.keychain-db
|
||||
|
||||
# import certificate and provisioning profile from secrets
|
||||
echo -n "$BUILD_CERTIFICATE_BASE64" | base64 --decode -o $CERTIFICATE_PATH
|
||||
|
||||
# create temporary keychain
|
||||
security create-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH
|
||||
security set-keychain-settings -lut 21600 $KEYCHAIN_PATH
|
||||
security unlock-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH
|
||||
|
||||
# import certificate to keychain
|
||||
security import $CERTIFICATE_PATH -P "$P12_PASSWORD" -A -t cert -f pkcs12 -k $KEYCHAIN_PATH
|
||||
security list-keychain -d user -s $KEYCHAIN_PATH
|
||||
|
||||
# store notarytool credentials
|
||||
xcrun notarytool store-credentials --apple-id "$APPLE_ID" --password "$APPLE_APP_PASSWORD" --team-id "$APPLE_TEAM_ID" notarytool --validate
|
||||
|
||||
uv run make bundle_mac
|
||||
poetry run make bundle_mac
|
||||
|
||||
elif [ "$RUNNER_OS" == "Windows" ]; then
|
||||
|
||||
cp -r ./dll_backup ./buzz/
|
||||
uv run make bundle_windows
|
||||
poetry run make bundle_windows
|
||||
|
||||
fi
|
||||
env:
|
||||
|
|
@ -264,54 +203,43 @@ jobs:
|
|||
name: Buzz-${{ runner.os }}-${{ runner.arch }}
|
||||
path: |
|
||||
dist/Buzz*-windows.exe
|
||||
dist/Buzz*-windows-*.bin
|
||||
dist/Buzz*-mac.dmg
|
||||
|
||||
build_wheels:
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-15-intel, macos-latest]
|
||||
# macos-13 is an intel runner, macos-14 is apple silicon
|
||||
os: [ubuntu-latest, windows-latest, macos-13, macos-14]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
# Should be removed with next update to whisper.cpp
|
||||
- name: Downgrade Xcode
|
||||
uses: maxim-lobanov/setup-xcode@v1
|
||||
with:
|
||||
xcode-version: '16.0.0'
|
||||
if: matrix.os == 'macos-latest'
|
||||
|
||||
- name: Install Vulkan SDK
|
||||
if: "startsWith(matrix.os, 'ubuntu-') || matrix.os == 'windows-latest'"
|
||||
uses: humbletim/install-vulkan-sdk@v1.2
|
||||
with:
|
||||
version: 1.4.309.0
|
||||
cache: true
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
- name: Copy Windows DLLs
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
cp -r ./dll_backup ./buzz/
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Build wheels
|
||||
run: uv build --wheel
|
||||
shell: bash
|
||||
uses: pypa/cibuildwheel@v2.19.2
|
||||
env:
|
||||
CIBW_ARCHS_WINDOWS: "auto"
|
||||
CIBW_ARCHS_MACOS: "universal2"
|
||||
CIBW_ARCHS_LINUX: "auto"
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: buzz-wheel-${{ runner.os }}-${{ runner.arch }}
|
||||
path: ./dist/*.whl
|
||||
name: cibw-wheels-${{ matrix.os }}
|
||||
path: ./wheelhouse/*.whl
|
||||
|
||||
publish_pypi:
|
||||
needs: [build_wheels, test]
|
||||
needs: [ build_wheels, test ]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
environment: pypi
|
||||
permissions:
|
||||
id-token: write
|
||||
|
|
@ -319,7 +247,7 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: buzz-wheel-*
|
||||
pattern: cibw-*
|
||||
path: dist
|
||||
merge-multiple: true
|
||||
|
||||
|
|
@ -330,16 +258,14 @@ jobs:
|
|||
|
||||
release:
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-15-intel
|
||||
- os: macos-13
|
||||
- os: macos-latest
|
||||
- os: windows-latest
|
||||
needs: [build, test]
|
||||
needs: [ build, test ]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
|
@ -351,52 +277,39 @@ jobs:
|
|||
name: Buzz-${{ runner.os }}-${{ runner.arch }}
|
||||
|
||||
- name: Rename .dmg files
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
for file in Buzz*.dmg; do
|
||||
mv "$file" "${file%.dmg}-${{ runner.arch }}.dmg"
|
||||
done
|
||||
|
||||
- name: Install Poetry Action
|
||||
uses: snok/install-poetry@v1.3.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: |
|
||||
Buzz*-unix.tar.gz
|
||||
Buzz*.exe
|
||||
Buzz*.bin
|
||||
Buzz*.dmg
|
||||
Buzz*-windows.exe
|
||||
Buzz*-mac.dmg
|
||||
|
||||
# Brew Cask deployment fails and the app is deprecated on Brew.
|
||||
# deploy_brew_cask:
|
||||
# runs-on: macos-latest
|
||||
# env:
|
||||
# BUZZ_DISABLE_TELEMETRY: true
|
||||
# needs: [release]
|
||||
# if: startsWith(github.ref, 'refs/tags/')
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# submodules: recursive
|
||||
#
|
||||
# # Should be removed with next update to whisper.cpp
|
||||
# - name: Downgrade Xcode
|
||||
# uses: maxim-lobanov/setup-xcode@v1
|
||||
# with:
|
||||
# xcode-version: '16.0.0'
|
||||
# if: matrix.os == 'macos-latest'
|
||||
#
|
||||
# - name: Install uv
|
||||
# uses: astral-sh/setup-uv@v6
|
||||
#
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: "3.12"
|
||||
#
|
||||
# - name: Install dependencies
|
||||
# run: uv sync
|
||||
#
|
||||
# - name: Upload to Brew
|
||||
# run: uv run make upload_brew
|
||||
# env:
|
||||
# HOMEBREW_GITHUB_API_TOKEN: ${{ secrets.HOMEBREW_GITHUB_API_TOKEN }}
|
||||
deploy_brew_cask:
|
||||
runs-on: macos-latest
|
||||
needs: [ release ]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install Poetry Action
|
||||
uses: snok/install-poetry@v1.3.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
- name: Upload to Brew
|
||||
run: make upload_brew
|
||||
env:
|
||||
HOMEBREW_GITHUB_API_TOKEN: ${{ secrets.HOMEBREW_GITHUB_API_TOKEN }}
|
||||
|
|
|
|||
4
.github/workflows/manual-build.yml
vendored
|
|
@ -9,8 +9,6 @@ concurrency:
|
|||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
|
@ -71,8 +69,6 @@ jobs:
|
|||
|
||||
build-snap:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
outputs:
|
||||
snap: ${{ steps.snapcraft.outputs.snap }}
|
||||
steps:
|
||||
|
|
|
|||
59
.github/workflows/snapcraft.yml
vendored
|
|
@ -14,58 +14,27 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-24.04
|
||||
timeout-minutes: 90
|
||||
env:
|
||||
BUZZ_DISABLE_TELEMETRY: true
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
snap: ${{ steps.snapcraft.outputs.snap }}
|
||||
steps:
|
||||
# Ideas from https://github.com/orgs/community/discussions/25678
|
||||
- name: Remove unused build tools
|
||||
run: |
|
||||
sudo apt-get remove -y azure-cli google-cloud-sdk hhvm google-chrome-stable firefox powershell mono-devel || true
|
||||
sudo apt-get autoremove -y
|
||||
sudo apt-get clean
|
||||
python -m pip cache purge
|
||||
rm -rf /opt/hostedtoolcache || true
|
||||
- name: Check available disk space
|
||||
run: |
|
||||
echo "=== Disk space ==="
|
||||
df -h
|
||||
echo "=== Memory ==="
|
||||
free -h
|
||||
- name: Maximize build space
|
||||
uses: easimon/maximize-build-space@master
|
||||
with:
|
||||
root-reserve-mb: 20000
|
||||
swap-size-mb: 1024
|
||||
remove-dotnet: 'true'
|
||||
remove-android: 'true'
|
||||
remove-haskell: 'true'
|
||||
remove-codeql: 'true'
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install Snapcraft and dependencies
|
||||
run: |
|
||||
set -x
|
||||
# Ensure snapd is ready
|
||||
sudo systemctl start snapd.socket
|
||||
sudo snap wait system seed.loaded
|
||||
|
||||
echo "=== Installing snapcraft ==="
|
||||
sudo snap install --classic snapcraft
|
||||
|
||||
echo "=== Installing gnome extension dependencies ==="
|
||||
sudo snap install gnome-46-2404 || { echo "Failed to install gnome-46-2404"; sudo journalctl -u snapd --no-pager -n 50; exit 1; }
|
||||
sudo snap install gnome-46-2404-sdk || { echo "Failed to install gnome-46-2404-sdk"; sudo journalctl -u snapd --no-pager -n 50; exit 1; }
|
||||
|
||||
echo "=== Installing build-snaps ==="
|
||||
sudo snap install --classic astral-uv || { echo "Failed to install astral-uv"; sudo journalctl -u snapd --no-pager -n 50; exit 1; }
|
||||
|
||||
echo "=== Installed snaps ==="
|
||||
snap list
|
||||
- name: Check disk space before build
|
||||
run: df -h
|
||||
- name: Build snap
|
||||
- uses: snapcore/action-build@v1
|
||||
id: snapcraft
|
||||
env:
|
||||
SNAPCRAFT_BUILD_ENVIRONMENT: host
|
||||
run: |
|
||||
sudo -E snapcraft pack --verbose --destructive-mode
|
||||
echo "snap=$(ls *.snap)" >> $GITHUB_OUTPUT
|
||||
- run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libportaudio2 libtbb-dev
|
||||
- run: sudo snap install --devmode *.snap
|
||||
- run: |
|
||||
cd $HOME
|
||||
|
|
|
|||
17
.gitignore
vendored
|
|
@ -5,21 +5,19 @@ build/
|
|||
.coverage*
|
||||
!.coveragerc
|
||||
.env
|
||||
.DS_Store
|
||||
htmlcov/
|
||||
coverage.xml
|
||||
.idea/
|
||||
.venv/
|
||||
venv/
|
||||
.claude/
|
||||
|
||||
# whisper_cpp
|
||||
libwhisper.*
|
||||
libwhisper-coreml.*
|
||||
whisper_cpp
|
||||
*.exe
|
||||
*.dll
|
||||
*.dylib
|
||||
*.so
|
||||
buzz/whisper_cpp/*
|
||||
whisper_cpp.exe
|
||||
whisper.dll
|
||||
buzz/whisper_cpp.py
|
||||
buzz/whisper_cpp_coreml.py
|
||||
|
||||
# Internationalization - compiled binaries
|
||||
*.mo
|
||||
|
|
@ -31,6 +29,3 @@ benchmarks.json
|
|||
*.egg-info
|
||||
/coverage/
|
||||
/wheelhouse/
|
||||
/.flatpak-builder
|
||||
/repo
|
||||
/nemo_msdd_configs
|
||||
|
|
|
|||
12
.gitmodules
vendored
|
|
@ -1,15 +1,3 @@
|
|||
[submodule "whisper.cpp"]
|
||||
path = whisper.cpp
|
||||
url = https://github.com/ggerganov/whisper.cpp
|
||||
[submodule "whisper_diarization"]
|
||||
path = whisper_diarization
|
||||
url = https://github.com/MahmoudAshraf97/whisper-diarization
|
||||
[submodule "demucs_repo"]
|
||||
path = demucs_repo
|
||||
url = https://github.com/MahmoudAshraf97/demucs.git
|
||||
[submodule "deepmultilingualpunctuation"]
|
||||
path = deepmultilingualpunctuation
|
||||
url = https://github.com/oliverguhr/deepmultilingualpunctuation.git
|
||||
[submodule "ctc_forced_aligner"]
|
||||
path = ctc_forced_aligner
|
||||
url = https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
3.12
|
||||
85
Buzz.spec
|
|
@ -1,5 +1,4 @@
|
|||
# -*- mode: python ; coding: utf-8 -*-
|
||||
import os
|
||||
import os.path
|
||||
import platform
|
||||
import shutil
|
||||
|
|
@ -10,7 +9,6 @@ from buzz.__version__ import VERSION
|
|||
|
||||
datas = []
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("demucs")
|
||||
datas += copy_metadata("tqdm")
|
||||
datas += copy_metadata("torch")
|
||||
datas += copy_metadata("regex")
|
||||
|
|
@ -22,34 +20,12 @@ datas += copy_metadata("tokenizers")
|
|||
datas += copy_metadata("huggingface-hub")
|
||||
datas += copy_metadata("safetensors")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("julius")
|
||||
datas += copy_metadata("openunmix")
|
||||
datas += copy_metadata("lameenc")
|
||||
datas += copy_metadata("diffq")
|
||||
datas += copy_metadata("einops")
|
||||
datas += copy_metadata("hydra-core")
|
||||
datas += copy_metadata("hydra-colorlog")
|
||||
datas += copy_metadata("museval")
|
||||
datas += copy_metadata("submitit")
|
||||
datas += copy_metadata("treetable")
|
||||
datas += copy_metadata("soundfile")
|
||||
datas += copy_metadata("dora-search")
|
||||
datas += copy_metadata("lhotse")
|
||||
|
||||
# Allow transformers package to load __init__.py file dynamically:
|
||||
# https://github.com/chidiwilliams/buzz/issues/272
|
||||
datas += collect_data_files("transformers", include_py_files=True)
|
||||
|
||||
datas += collect_data_files("faster_whisper", include_py_files=True)
|
||||
datas += collect_data_files("stable_whisper", include_py_files=True)
|
||||
datas += collect_data_files("whisper")
|
||||
datas += collect_data_files("demucs", include_py_files=True)
|
||||
datas += collect_data_files("whisper_diarization", include_py_files=True)
|
||||
datas += collect_data_files("deepmultilingualpunctuation", include_py_files=True)
|
||||
datas += collect_data_files("ctc_forced_aligner", include_py_files=True, excludes=["build"])
|
||||
datas += collect_data_files("nemo", include_py_files=True)
|
||||
datas += collect_data_files("lightning_fabric", include_py_files=True)
|
||||
datas += collect_data_files("pytorch_lightning", include_py_files=True)
|
||||
datas += [("buzz/assets/*", "assets")]
|
||||
datas += [("buzz/locale", "locale")]
|
||||
datas += [("buzz/schema.sql", ".")]
|
||||
|
|
@ -62,65 +38,32 @@ if DEBUG:
|
|||
else:
|
||||
options = []
|
||||
|
||||
def find_dependency(name: str) -> str:
|
||||
paths = os.environ["PATH"].split(os.pathsep)
|
||||
candidates = []
|
||||
for path in paths:
|
||||
exe_path = os.path.join(path, name)
|
||||
if os.path.isfile(exe_path):
|
||||
candidates.append(exe_path)
|
||||
binaries = [
|
||||
(
|
||||
"buzz/whisper.dll" if platform.system() == "Windows" else "buzz/libwhisper.*",
|
||||
".",
|
||||
),
|
||||
(shutil.which("ffmpeg"), "."),
|
||||
(shutil.which("ffprobe"), "."),
|
||||
]
|
||||
|
||||
# Check for chocolatery shims
|
||||
shim_path = os.path.normpath(os.path.join(path, "..", "lib", "ffmpeg", "tools", "ffmpeg", "bin", name))
|
||||
if os.path.isfile(shim_path):
|
||||
candidates.append(shim_path)
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Pick the largest file
|
||||
return max(candidates, key=lambda f: os.path.getsize(f))
|
||||
|
||||
if platform.system() == "Windows":
|
||||
binaries = [
|
||||
(find_dependency("ffmpeg.exe"), "."),
|
||||
(find_dependency("ffprobe.exe"), "."),
|
||||
]
|
||||
else:
|
||||
binaries = [
|
||||
(shutil.which("ffmpeg"), "."),
|
||||
(shutil.which("ffprobe"), "."),
|
||||
]
|
||||
|
||||
binaries.append(("buzz/whisper_cpp/*", "buzz/whisper_cpp"))
|
||||
# Include libwhisper-coreml.dylib on Apple Silicon
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
binaries.append(("buzz/libwhisper-coreml.dylib", "."))
|
||||
|
||||
# Include dll_backup folder and its contents on Windows
|
||||
if platform.system() == "Windows":
|
||||
datas += [("dll_backup", "dll_backup")]
|
||||
datas += collect_data_files("msvc-runtime")
|
||||
|
||||
binaries.append(("dll_backup/SDL2.dll", "dll_backup"))
|
||||
binaries.append(("dll_backup/whisper.dll", "dll_backup"))
|
||||
|
||||
a = Analysis(
|
||||
["main.py"],
|
||||
pathex=[],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=[
|
||||
"dora", "dora.log",
|
||||
"julius", "julius.core", "julius.resample",
|
||||
"openunmix", "openunmix.filtering",
|
||||
"lameenc",
|
||||
"diffq",
|
||||
"einops",
|
||||
"hydra", "hydra.core", "hydra.core.global_hydra",
|
||||
"hydra_colorlog",
|
||||
"museval",
|
||||
"submitit",
|
||||
"treetable",
|
||||
"soundfile",
|
||||
"_soundfile_data",
|
||||
"lhotse",
|
||||
],
|
||||
hiddenimports=[],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
- Use uv to run tests and any scripts
|
||||
86
CONTRIBUTING.md
Executable file → Normal file
|
|
@ -28,8 +28,7 @@ What version of the Buzz are you using? On what OS? What are steps to reproduce
|
|||
**Logs**
|
||||
|
||||
Log files contain valuable information about what the Buzz was doing before the issue occurred. You can get the logs like this:
|
||||
* Linux run the app from the terminal and check the output.
|
||||
* Mac get logs from `~/Library/Logs/Buzz`.
|
||||
* Mac and Linux run the app from the terminal and check the output.
|
||||
* Windows paste this into the Windows Explorer address bar `%USERPROFILE%\AppData\Local\Buzz\Buzz\Logs` and check the logs file.
|
||||
|
||||
**Test on latest version**
|
||||
|
|
@ -46,15 +45,16 @@ Linux versions get also pushed to the snap. To install latest development versio
|
|||
|
||||
1. Clone the repository `git clone --recursive https://github.com/chidiwilliams/buzz.git`
|
||||
2. Enter repo folder `cd buzz`
|
||||
3. Install uv `curl -LsSf https://astral.sh/uv/install.sh | sh` (or see [uv installation docs](https://docs.astral.sh/uv/getting-started/installation/))
|
||||
4. Install system dependencies you may be missing
|
||||
3. Install Poetry `sudo apt-get install python3-poetry`
|
||||
4. Activate the virtual environment `poetry shell`
|
||||
5. Install the dependencies `poetry install`
|
||||
6. Install system dependencies you may be missing
|
||||
```
|
||||
sudo apt-get install --no-install-recommends libyaml-dev libtbb-dev libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-shape0 libxcb-cursor0 libportaudio2 gettext libpulse0 ffmpeg
|
||||
```
|
||||
On versions prior to Ubuntu 24.04 install `sudo apt-get install --no-install-recommends libegl1-mesa`
|
||||
|
||||
5. Install the dependencies `uv sync`
|
||||
6. Run Buzz `uv run buzz`
|
||||
7. Build Buzz `poetry build`
|
||||
8. Run Buzz `python -m buzz`
|
||||
|
||||
#### Necessary dependencies for Faster Whisper on GPU
|
||||
|
||||
|
|
@ -64,24 +64,23 @@ On versions prior to Ubuntu 24.04 install `sudo apt-get install --no-install-rec
|
|||
#### Error for Faster Whisper on GPU `Could not load library libcudnn_ops_infer.so.8`
|
||||
|
||||
You need to add path to the library to the `LD_LIBRARY_PATH` environment variable.
|
||||
Check exact path to your uv virtual environment, it may be different for you.
|
||||
Check exact path to your poetry virtual environment, it may be different for you.
|
||||
|
||||
```
|
||||
export LD_LIBRARY_PATH=/path/to/buzz/.venv/lib/python3.12/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH
|
||||
export LD_LIBRARY_PATH=/home/PutYourUserNameHere/.cache/pypoetry/virtualenvs/buzz-captions-JjGFxAW6-py3.12/lib/python3.12/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
#### For Whisper.cpp you will need to install Vulkan SDK
|
||||
|
||||
Follow the instructions for your distribution https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
### Mac
|
||||
|
||||
1. Clone the repository `git clone --recursive https://github.com/chidiwilliams/buzz.git`
|
||||
2. Enter repo folder `cd buzz`
|
||||
3. Install uv `curl -LsSf https://astral.sh/uv/install.sh | sh` (or `brew install uv`)
|
||||
4. Install system dependencies you may be missing `brew install ffmpeg`
|
||||
5. Install the dependencies `uv sync`
|
||||
6. Run Buzz `uv run buzz`
|
||||
3. Install Poetry `brew install poetry`
|
||||
4. Activate the virtual environment `poetry shell`
|
||||
5. Install the dependencies `poetry install`
|
||||
6. Install system dependencies you may be missing `brew install ffmpeg`
|
||||
7. Build Buzz `poetry build`
|
||||
8. Run Buzz `python -m buzz`
|
||||
|
||||
|
||||
|
||||
|
|
@ -93,31 +92,52 @@ Assumes you have [Git](https://git-scm.com/downloads) and [python](https://www.p
|
|||
```
|
||||
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
|
||||
```
|
||||
2. Install the build tools. `choco install make cmake`
|
||||
2. Install the GNU make. `choco install make`
|
||||
3. Install the ffmpeg. `choco install ffmpeg`
|
||||
4. Download [Build Tools for Visual Studio 2022](https://visualstudio.microsoft.com/vs/older-downloads/) and install "Desktop development with C++" workload.
|
||||
5. Add location of `nmake` to your PATH environment variable. Usually it is `C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\bin\Hostx64\x86`
|
||||
6. Install Vulkan SDK from https://vulkan.lunarg.com/sdk/home
|
||||
7. Clone the repository `git clone --recursive https://github.com/chidiwilliams/buzz.git`
|
||||
8. Enter repo folder `cd buzz`
|
||||
9. Install uv `powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"`
|
||||
10. Install the dependencies `uv sync`
|
||||
11. Build Whisper.cpp `uv run make buzz/whisper_cpp`
|
||||
12. `cp -r .\dll_backup\ .\buzz\`
|
||||
13. Run Buzz `uv run buzz`
|
||||
4. Install [MSYS2](https://www.msys2.org/), follow [this guide](https://sajidifti.medium.com/how-to-install-gcc-and-gdb-on-windows-using-msys2-tutorial-0fceb7e66454).
|
||||
5. Install Poetry, paste this info Windows PowerShell line by line. [More info](https://python-poetry.org/docs/)
|
||||
```
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | py -
|
||||
|
||||
Note: It should be safe to ignore any "syntax errors" you see during the build. Buzz will work. Also you can ignore any errors for FFmpeg. Buzz tries to load FFmpeg by several different means and some of them throw errors, but FFmpeg should eventually be found and work.
|
||||
[Environment]::SetEnvironmentVariable("Path", $env:Path + ";%APPDATA%\pypoetry\venv\Scripts", "User")
|
||||
|
||||
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
|
||||
```
|
||||
6. Add poetry to PATH. `%APPDATA%\Python\Scripts`
|
||||
7. Restart Windows.
|
||||
8. Clone the repository `git clone --recursive https://github.com/chidiwilliams/buzz.git`
|
||||
9. Enter repo folder `cd buzz`
|
||||
10. Activate the virtual environment `poetry shell`
|
||||
11. Install the dependencies `poetry install`
|
||||
12. `cp -r .\dll_backup\ .\buzz\`
|
||||
13. Build Buzz `poetry build`
|
||||
14. Run Buzz `python -m buzz`
|
||||
|
||||
#### GPU Support
|
||||
|
||||
GPU support on Windows with Nvidia GPUs is included out of the box in the `.exe` installer.
|
||||
GPU support on Windows is possible for Buzz that is installed from the source code or with `pip`.
|
||||
Use the instructions above to install the Buzz from the source code or run `pip install buzz-captions`
|
||||
and then follow the instructions below to enable CUDA GPU support. For pip installation it is recommended to use
|
||||
a separate [virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
To add GPU support for source or `pip` installed version switch torch library to GPU version. For more info see https://pytorch.org/get-started/locally/ .
|
||||
To enable GPU support first ensure CUDA 12.1 is installed - https://developer.nvidia.com/cuda-12-1-0-download-archive
|
||||
Other versions of CUDA 12 should also work.
|
||||
|
||||
Switch torch library to GPU version. It must match the CUDA version installed, see https://pytorch.org/get-started/locally/ .
|
||||
```
|
||||
uv add --index https://download.pytorch.org/whl/cu128 torch==2.7.1+cu128 torchaudio==2.7.1+cu128
|
||||
uv add --index https://pypi.ngc.nvidia.com nvidia-cublas-cu12==12.8.3.14 nvidia-cuda-cupti-cu12==12.8.57 nvidia-cuda-nvrtc-cu12==12.8.61 nvidia-cuda-runtime-cu12==12.8.57 nvidia-cudnn-cu12==9.7.1.26 nvidia-cufft-cu12==11.3.3.41 nvidia-curand-cu12==10.3.9.55 nvidia-cusolver-cu12==11.7.2.55 nvidia-cusparse-cu12==12.5.4.2 nvidia-cusparselt-cu12==0.6.3 nvidia-nvjitlink-cu12==12.8.61 nvidia-nvtx-cu12==12.8.55
|
||||
pip3 uninstall torch torchaudio
|
||||
pip3 install torch==2.2.1+cu121 torchaudio==2.2.1+cu121 --index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
To use Faster Whisper on GPU, install the following libraries:
|
||||
* [cuBLAS](https://developer.nvidia.com/cublas)
|
||||
* [cuDNN](https://developer.nvidia.com/cudnn)
|
||||
* [cuDNN](https://developer.nvidia.com/cudnn)
|
||||
|
||||
Ensure ffmpeg dependencies are installed
|
||||
```
|
||||
pip3 uninstall ffmpeg ffmpeg-python
|
||||
pip3 install ffmpeg
|
||||
pip3 install ffmpeg-python
|
||||
```
|
||||
|
||||
Run Buzz `python -m buzz`
|
||||
183
Makefile
|
|
@ -1,109 +1,112 @@
|
|||
# Change also in pyproject.toml and buzz/__version__.py
|
||||
version := 1.4.4
|
||||
version := $$(poetry version -s)
|
||||
version_escaped := $$(echo ${version} | sed -e 's/\./\\./g')
|
||||
|
||||
mac_app_path := ./dist/Buzz.app
|
||||
mac_zip_path := ./dist/Buzz-${version}-mac.zip
|
||||
mac_dmg_path := ./dist/Buzz-${version}-mac.dmg
|
||||
|
||||
bundle_windows: dist/Buzz
|
||||
iscc installer.iss
|
||||
iscc //DAppVersion=${version} installer.iss
|
||||
|
||||
bundle_mac: dist/Buzz.app codesign_all_mac zip_mac notarize_zip staple_app_mac dmg_mac
|
||||
|
||||
bundle_mac_unsigned: dist/Buzz.app zip_mac dmg_mac_unsigned
|
||||
|
||||
UNAME_S := $(shell uname -s)
|
||||
UNAME_M := $(shell uname -m)
|
||||
|
||||
LIBWHISPER :=
|
||||
ifeq ($(OS), Windows_NT)
|
||||
LIBWHISPER=whisper.dll
|
||||
else
|
||||
ifeq ($(UNAME_S), Darwin)
|
||||
LIBWHISPER=libwhisper.dylib
|
||||
else
|
||||
LIBWHISPER=libwhisper.so
|
||||
endif
|
||||
endif
|
||||
|
||||
clean:
|
||||
ifeq ($(OS), Windows_NT)
|
||||
-rmdir /s /q buzz\whisper_cpp
|
||||
-rmdir /s /q whisper.cpp\build
|
||||
-rmdir /s /q dist
|
||||
-Remove-Item -Recurse -Force buzz\whisper_cpp
|
||||
-Remove-Item -Recurse -Force whisper.cpp\build
|
||||
-Remove-Item -Recurse -Force dist\*
|
||||
-rm -rf buzz/whisper_cpp
|
||||
-rm -rf whisper.cpp/build
|
||||
-rm -rf dist/*
|
||||
-rm -rf buzz/__pycache__ buzz/**/__pycache__ buzz/**/**/__pycache__ buzz/**/**/**/__pycache__
|
||||
-for /d /r buzz %%d in (__pycache__) do @if exist "%%d" rmdir /s /q "%%d"
|
||||
-del /f buzz\$(LIBWHISPER) 2> nul
|
||||
-del /f buzz\whisper_cpp.py 2> nul
|
||||
-rmdir /s /q whisper.cpp\build 2> nul
|
||||
-rmdir /s /q dist 2> nul
|
||||
-rm -f buzz/$(LIBWHISPER)
|
||||
-rm -f buzz/whisper_cpp.py
|
||||
-rm -rf whisper.cpp/build || true
|
||||
-rm -rf dist/* || true
|
||||
else
|
||||
rm -rf buzz/whisper_cpp || true
|
||||
rm -f buzz/$(LIBWHISPER)
|
||||
rm -f buzz/whisper_cpp.py
|
||||
rm -f buzz/libwhisper-coreml.dylib || true
|
||||
rm -f buzz/whisper_cpp_coreml.py || true
|
||||
rm -rf whisper.cpp/build || true
|
||||
rm -rf dist/* || true
|
||||
find buzz -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
endif
|
||||
|
||||
COVERAGE_THRESHOLD := 70
|
||||
COVERAGE_THRESHOLD := 75
|
||||
|
||||
test: buzz/whisper_cpp
|
||||
# A check to get updates of yt-dlp. Should run only on local as part of regular development operations
|
||||
# Sort of a local "update checker"
|
||||
ifndef CI
|
||||
uv lock --upgrade-package yt-dlp
|
||||
endif
|
||||
pytest -s -vv --cov=buzz --cov-report=xml --cov-report=html --benchmark-skip --cov-fail-under=${COVERAGE_THRESHOLD} --cov-config=.coveragerc
|
||||
test: buzz/whisper_cpp.py translation_mo
|
||||
pytest -s -vv --cov=buzz --cov-report=xml --cov-report=html --benchmark-skip --cov-fail-under=${COVERAGE_THRESHOLD}
|
||||
|
||||
benchmarks: buzz/whisper_cpp
|
||||
benchmarks: buzz/whisper_cpp.py translation_mo
|
||||
pytest -s -vv --benchmark-only --benchmark-json benchmarks.json
|
||||
|
||||
dist/Buzz dist/Buzz.app: buzz/whisper_cpp
|
||||
dist/Buzz dist/Buzz.app: buzz/whisper_cpp.py translation_mo
|
||||
pyinstaller --noconfirm Buzz.spec
|
||||
|
||||
version:
|
||||
poetry version ${version}
|
||||
echo "VERSION = \"${version}\"" > buzz/__version__.py
|
||||
|
||||
buzz/whisper_cpp: translation_mo
|
||||
ifeq ($(OS), Windows_NT)
|
||||
# Build Whisper with Vulkan support.
|
||||
# The _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR is needed to prevent mutex lock issues on Windows
|
||||
# https://github.com/actions/runner-images/issues/10004#issuecomment-2156109231
|
||||
# -DCMAKE_[C|CXX]_COMPILER_WORKS=TRUE is used to prevent issue in building test program that fails on CI
|
||||
# GGML_NATIVE=OFF ensures we don't use -march=native (which would target the build machine's CPU)
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=OFF -DCMAKE_INSTALL_RPATH='$$ORIGIN' -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON -DCMAKE_C_FLAGS="-D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR" -DCMAKE_CXX_FLAGS="-D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR" -DCMAKE_C_COMPILER_WORKS=TRUE -DCMAKE_CXX_COMPILER_WORKS=TRUE -DGGML_VULKAN=1 -DGGML_NATIVE=OFF
|
||||
cmake --build whisper.cpp/build -j --config Release --verbose
|
||||
|
||||
-mkdir buzz/whisper_cpp
|
||||
cp whisper.cpp/build/bin/Release/whisper-cli.exe buzz/whisper_cpp/
|
||||
cp whisper.cpp/build/bin/Release/whisper-server.exe buzz/whisper_cpp/
|
||||
cp dll_backup/SDL2.dll buzz/whisper_cpp
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "if (-not (Test-Path 'buzz\whisper_cpp\ggml-silero-v6.2.0.bin')) { Start-BitsTransfer -Source https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v6.2.0.bin -Destination 'buzz\whisper_cpp\ggml-silero-v6.2.0.bin' }"
|
||||
endif
|
||||
|
||||
ifeq ($(shell uname -s), Linux)
|
||||
# Build Whisper with Vulkan support
|
||||
# GGML_NATIVE=OFF ensures we don't use -march=native (which would target the build machine's CPU)
|
||||
# This enables portable SSE4.2/AVX/AVX2 optimizations that work on most x86_64 CPUs
|
||||
rm -rf whisper.cpp/build || true
|
||||
-mkdir -p buzz/whisper_cpp
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_RPATH='$$ORIGIN' -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON -DGGML_VULKAN=1 -DGGML_NATIVE=OFF
|
||||
cmake --build whisper.cpp/build -j --config Release --verbose
|
||||
cp whisper.cpp/build/bin/whisper-cli buzz/whisper_cpp/ || true
|
||||
cp whisper.cpp/build/bin/whisper-server buzz/whisper_cpp/ || true
|
||||
cp -P whisper.cpp/build/src/libwhisper.so* buzz/whisper_cpp/ || true
|
||||
cp -P whisper.cpp/build/ggml/src/libggml.so* buzz/whisper_cpp/ || true
|
||||
cp -P whisper.cpp/build/ggml/src/libggml-base.so* buzz/whisper_cpp/ || true
|
||||
cp -P whisper.cpp/build/ggml/src/libggml-cpu.so* buzz/whisper_cpp/ || true
|
||||
cp -P whisper.cpp/build/ggml/src/ggml-vulkan/libggml-vulkan.so* buzz/whisper_cpp/ || true
|
||||
test -f buzz/whisper_cpp/ggml-silero-v6.2.0.bin || curl -L -o buzz/whisper_cpp/ggml-silero-v6.2.0.bin https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v6.2.0.bin
|
||||
endif
|
||||
|
||||
# Build on Macs
|
||||
ifeq ($(shell uname -s), Darwin)
|
||||
-rm -rf whisper.cpp/build || true
|
||||
-mkdir -p buzz/whisper_cpp
|
||||
|
||||
ifeq ($(shell uname -m), arm64)
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON -DWHISPER_COREML=1
|
||||
CMAKE_FLAGS=
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifeq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||
CMAKE_FLAGS += -DWHISPER_NO_AVX=ON
|
||||
endif
|
||||
ifeq (,$(findstring FMA,$(AVX1_M)))
|
||||
CMAKE_FLAGS += -DWHISPER_NO_FMA=ON
|
||||
endif
|
||||
AVX2_M := $(shell sysctl machdep.cpu.leaf7_features)
|
||||
ifeq (,$(findstring AVX2,$(AVX2_M)))
|
||||
CMAKE_FLAGS += -DWHISPER_NO_AVX2=ON
|
||||
endif
|
||||
CMAKE_FLAGS += -DCMAKE_OSX_ARCHITECTURES="x86_64;arm64"
|
||||
else
|
||||
# Intel
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON -DGGML_VULKAN=0 -DGGML_METAL=0
|
||||
ifeq ($(OS), Windows_NT)
|
||||
CMAKE_FLAGS += -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
|
||||
endif
|
||||
endif
|
||||
|
||||
cmake --build whisper.cpp/build -j --config Release --verbose
|
||||
cp whisper.cpp/build/bin/whisper-cli buzz/whisper_cpp/ || true
|
||||
cp whisper.cpp/build/bin/whisper-server buzz/whisper_cpp/ || true
|
||||
cp whisper.cpp/build/src/libwhisper.dylib buzz/whisper_cpp/ || true
|
||||
cp whisper.cpp/build/ggml/src/libggml* buzz/whisper_cpp/ || true
|
||||
test -f buzz/whisper_cpp/ggml-silero-v6.2.0.bin || curl -L -o buzz/whisper_cpp/ggml-silero-v6.2.0.bin https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v6.2.0.bin
|
||||
buzz/$(LIBWHISPER):
|
||||
ifeq ($(OS), Windows_NT)
|
||||
cp dll_backup/whisper.dll buzz || copy dll_backup\whisper.dll buzz\whisper.dll
|
||||
cp dll_backup/SDL2.dll buzz || copy dll_backup\SDL2.dll buzz\SDL2.dll
|
||||
else
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ $(CMAKE_FLAGS)
|
||||
cmake --build whisper.cpp/build --verbose
|
||||
cp whisper.cpp/build/bin/Debug/$(LIBWHISPER) buzz || true
|
||||
cp whisper.cpp/build/$(LIBWHISPER) buzz || true
|
||||
endif
|
||||
# Build CoreML support on ARM Macs
|
||||
ifeq ($(shell uname -m), arm64)
|
||||
ifeq ($(shell uname -s), Darwin)
|
||||
rm -rf whisper.cpp/build || true
|
||||
cmake -S whisper.cpp -B whisper.cpp/build/ $(CMAKE_FLAGS) -DWHISPER_COREML=1
|
||||
cmake --build whisper.cpp/build --verbose
|
||||
cp whisper.cpp/build/bin/Debug/$(LIBWHISPER) buzz/libwhisper-coreml.dylib || true
|
||||
cp whisper.cpp/build/$(LIBWHISPER) buzz/libwhisper-coreml.dylib || true
|
||||
endif
|
||||
endif
|
||||
|
||||
buzz/whisper_cpp.py: buzz/$(LIBWHISPER) translation_mo
|
||||
cd buzz && ctypesgen ../whisper.cpp/whisper.h -lwhisper -o whisper_cpp.py
|
||||
ifeq ($(shell uname -m), arm64)
|
||||
ifeq ($(shell uname -s), Darwin)
|
||||
cd buzz && ctypesgen ../whisper.cpp/whisper.h -lwhisper-coreml -o whisper_cpp_coreml.py
|
||||
endif
|
||||
endif
|
||||
|
||||
# Prints all the Mac developer identities used for code signing
|
||||
|
|
@ -173,9 +176,11 @@ codesign_all_mac: dist/Buzz.app
|
|||
notarize_log:
|
||||
xcrun notarytool log ${id} --keychain-profile "$$BUZZ_KEYCHAIN_NOTARY_PROFILE"
|
||||
|
||||
VENV_PATH := $(shell poetry env info -p)
|
||||
|
||||
# Make GGML model from whisper. Example: make ggml model_path=/Users/chidiwilliams/.cache/whisper/medium.pt
|
||||
ggml:
|
||||
python3 ./whisper.cpp/models/convert-pt-to-ggml.py ${model_path} .venv/lib/python3.12/site-packages/whisper dist
|
||||
python3 ./whisper.cpp/models/convert-pt-to-ggml.py ${model_path} $(VENV_PATH)/src/whisper dist
|
||||
|
||||
upload_brew:
|
||||
brew bump-cask-pr --version ${version} --verbose buzz
|
||||
|
|
@ -197,30 +202,20 @@ gh_upgrade_pr:
|
|||
|
||||
translation_po_all:
|
||||
$(MAKE) translation_po locale=ca_ES
|
||||
$(MAKE) translation_po locale=da_DK
|
||||
$(MAKE) translation_po locale=de_DE
|
||||
$(MAKE) translation_po locale=en_US
|
||||
$(MAKE) translation_po locale=es_ES
|
||||
$(MAKE) translation_po locale=it_IT
|
||||
$(MAKE) translation_po locale=ja_JP
|
||||
$(MAKE) translation_po locale=lv_LV
|
||||
$(MAKE) translation_po locale=nl
|
||||
$(MAKE) translation_po locale=pl_PL
|
||||
$(MAKE) translation_po locale=pt_BR
|
||||
$(MAKE) translation_po locale=uk_UA
|
||||
$(MAKE) translation_po locale=zh_CN
|
||||
$(MAKE) translation_po locale=zh_TW
|
||||
$(MAKE) translation_po locale=it_IT
|
||||
$(MAKE) translation_po locale=lv_LV
|
||||
$(MAKE) translation_po locale=uk_UA
|
||||
$(MAKE) translation_po locale=ja_JP
|
||||
|
||||
TMP_POT_FILE_PATH := $(shell mktemp)
|
||||
PO_FILE_PATH := buzz/locale/${locale}/LC_MESSAGES/buzz.po
|
||||
translation_po:
|
||||
mkdir -p buzz/locale/${locale}/LC_MESSAGES
|
||||
xgettext --from-code=UTF-8 --add-location=file -o "${TMP_POT_FILE_PATH}" -l python $(shell find buzz -name '*.py')
|
||||
sed -i.bak 's/CHARSET/UTF-8/' ${TMP_POT_FILE_PATH}
|
||||
if [ ! -f ${PO_FILE_PATH} ]; then \
|
||||
msginit --no-translator --input=${TMP_POT_FILE_PATH} --output-file=${PO_FILE_PATH}; \
|
||||
fi
|
||||
rm ${TMP_POT_FILE_PATH}.bak
|
||||
xgettext --from-code=UTF-8 -o "${TMP_POT_FILE_PATH}" -l python $(shell find buzz -name '*.py')
|
||||
sed -i.bak 's/CHARSET/UTF-8/' ${TMP_POT_FILE_PATH} && rm ${TMP_POT_FILE_PATH}.bak
|
||||
msgmerge -U ${PO_FILE_PATH} ${TMP_POT_FILE_PATH}
|
||||
|
||||
# On windows we can have two ways to compile locales, one for CI the other for local builds
|
||||
|
|
@ -233,7 +228,7 @@ ifeq ($(OS), Windows_NT)
|
|||
done
|
||||
else
|
||||
for dir in buzz/locale/*/ ; do \
|
||||
python3 msgfmt.py -o $$dir/LC_MESSAGES/buzz.mo $$dir/LC_MESSAGES/buzz.po; \
|
||||
python msgfmt.py -o $$dir/LC_MESSAGES/buzz.mo $$dir/LC_MESSAGES/buzz.po; \
|
||||
done
|
||||
endif
|
||||
|
||||
|
|
|
|||
|
|
@ -1,98 +0,0 @@
|
|||
# Buzz
|
||||
|
||||
[ドキュメント](https://chidiwilliams.github.io/buzz/)
|
||||
|
||||
パソコン上でオフラインで音声の文字起こしと翻訳を行います。OpenAIの[Whisper](https://github.com/openai/whisper)を使用しています。
|
||||
|
||||

|
||||
[](https://github.com/chidiwilliams/buzz/actions/workflows/ci.yml)
|
||||
[](https://codecov.io/github/chidiwilliams/buzz)
|
||||

|
||||
[](https://GitHub.com/chidiwilliams/buzz/releases/)
|
||||
|
||||

|
||||
|
||||
## 機能
|
||||
- 音声・動画ファイルまたはYouTubeリンクの文字起こし
|
||||
- マイクからのリアルタイム音声文字起こし
|
||||
- イベントやプレゼンテーション中に便利なプレゼンテーションウィンドウ
|
||||
- ノイズの多い音声でより高い精度を得るための、文字起こし前の話者分離
|
||||
- 文字起こしメディアでの話者識別
|
||||
- 複数のWhisperバックエンドをサポート
|
||||
- Nvidia GPU向けCUDAアクセラレーション対応
|
||||
- Mac向けApple Silicon対応
|
||||
- Whisper.cppでのVulkanアクセラレーション対応(統合GPUを含むほとんどのGPUで利用可能)
|
||||
- TXT、SRT、VTT形式での文字起こしエクスポート
|
||||
- 検索、再生コントロール、速度調整機能を備えた高度な文字起こしビューア
|
||||
- 効率的なナビゲーションのためのキーボードショートカット
|
||||
- 新しいファイルの自動文字起こしのための監視フォルダ
|
||||
- スクリプトや自動化のためのコマンドラインインターフェース
|
||||
|
||||
## インストール
|
||||
|
||||
### macOS
|
||||
|
||||
[SourceForge](https://sourceforge.net/projects/buzz-captions/files/)から`.dmg`ファイルをダウンロードしてください。
|
||||
|
||||
### Windows
|
||||
|
||||
[SourceForge](https://sourceforge.net/projects/buzz-captions/files/)からインストールファイルを入手してください。
|
||||
|
||||
アプリは署名されていないため、インストール時に警告が表示されます。`詳細情報` -> `実行`を選択してください。
|
||||
|
||||
### Linux
|
||||
|
||||
Buzzは[Flatpak](https://flathub.org/apps/io.github.chidiwilliams.Buzz)または[Snap](https://snapcraft.io/buzz)として利用可能です。
|
||||
|
||||
Flatpakをインストールするには、以下を実行してください:
|
||||
```shell
|
||||
flatpak install flathub io.github.chidiwilliams.Buzz
|
||||
```
|
||||
|
||||
[](https://flathub.org/en/apps/io.github.chidiwilliams.Buzz)
|
||||
|
||||
Snapをインストールするには、以下を実行してください:
|
||||
```shell
|
||||
sudo apt-get install libportaudio2 libcanberra-gtk-module libcanberra-gtk3-module
|
||||
sudo snap install buzz
|
||||
```
|
||||
|
||||
[](https://snapcraft.io/buzz)
|
||||
|
||||
### PyPI
|
||||
|
||||
[ffmpeg](https://www.ffmpeg.org/download.html)をインストールしてください。
|
||||
|
||||
Python 3.12環境を使用していることを確認してください。
|
||||
|
||||
Buzzをインストール
|
||||
|
||||
```shell
|
||||
pip install buzz-captions
|
||||
python -m buzz
|
||||
```
|
||||
|
||||
**PyPIでのGPUサポート**
|
||||
|
||||
PyPIでインストールしたバージョンでWindows上のNvidia GPUのGPUサポートを有効にするには、[torch](https://pytorch.org/get-started/locally/)のCUDAサポートを確認してください。
|
||||
|
||||
```
|
||||
pip3 install -U torch==2.8.0+cu129 torchaudio==2.8.0+cu129 --index-url https://download.pytorch.org/whl/cu129
|
||||
pip3 install nvidia-cublas-cu12==12.9.1.4 nvidia-cuda-cupti-cu12==12.9.79 nvidia-cuda-runtime-cu12==12.9.79 --extra-index-url https://pypi.ngc.nvidia.com
|
||||
```
|
||||
|
||||
### 最新開発版
|
||||
|
||||
最新の機能やバグ修正を含む最新開発版の入手方法については、[FAQ](https://chidiwilliams.github.io/buzz/docs/faq#9-where-can-i-get-latest-development-version)をご覧ください。
|
||||
|
||||
### スクリーンショット
|
||||
|
||||
<div style="display: flex; flex-wrap: wrap;">
|
||||
<img alt="ファイルインポート" src="share/screenshots/buzz-1-import.png" style="max-width: 18%; margin-right: 1%;" />
|
||||
<img alt="メイン画面" src="share/screenshots/buzz-2-main_screen.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="設定" src="share/screenshots/buzz-3-preferences.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="モデル設定" src="share/screenshots/buzz-3.2-model-preferences.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="文字起こし" src="share/screenshots/buzz-4-transcript.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="ライブ録音" src="share/screenshots/buzz-5-live_recording.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="リサイズ" src="share/screenshots/buzz-6-resize.png" style="max-width: 18%;" />
|
||||
</div>
|
||||
96
README.md
|
|
@ -1,8 +1,6 @@
|
|||
[[简体中文](readme/README.zh_CN.md)] <- 点击查看中文页面。
|
||||
|
||||
# Buzz
|
||||
|
||||
[Documentation](https://chidiwilliams.github.io/buzz/)
|
||||
[Documentation](https://chidiwilliams.github.io/buzz/) | [Buzz Captions on the App Store](https://apps.apple.com/us/app/buzz-captions/id6446018936?mt=12&itsct=apps_box_badge&itscg=30200)
|
||||
|
||||
Transcribe and translate audio offline on your personal computer. Powered by
|
||||
OpenAI's [Whisper](https://github.com/openai/whisper).
|
||||
|
|
@ -13,94 +11,48 @@ OpenAI's [Whisper](https://github.com/openai/whisper).
|
|||

|
||||
[](https://GitHub.com/chidiwilliams/buzz/releases/)
|
||||
|
||||

|
||||
<blockquote>
|
||||
<p>Buzz is better on the App Store. Get a Mac-native version of Buzz with a cleaner look, audio playback, drag-and-drop import, transcript editing, search, and much more.</p>
|
||||
<a href="https://apps.apple.com/us/app/buzz-captions/id6446018936?mt=12&itsct=apps_box_badge&itscg=30200"><img src="https://toolbox.marketingtools.apple.com/api/badges/download-on-the-mac-app-store/black/en-us?size=250x83&releaseDate=1679529600" alt="Download on the Mac App Store" /></a>
|
||||
</blockquote>
|
||||
|
||||
## Features
|
||||
- Transcribe audio and video files or Youtube links
|
||||
- Live realtime audio transcription from microphone
|
||||
- Presentation window for easy accessibility during events and presentations
|
||||
- Speech separation before transcription for better accuracy on noisy audio
|
||||
- Speaker identification in transcribed media
|
||||
- Multiple whisper backend support
|
||||
- CUDA acceleration support for Nvidia GPUs
|
||||
- Apple Silicon support for Macs
|
||||
- Vulkan acceleration support for Whisper.cpp on most GPUs, including integrated GPUs
|
||||
- Export transcripts to TXT, SRT, and VTT
|
||||
- Advanced Transcription Viewer with search, playback controls, and speed adjustment
|
||||
- Keyboard shortcuts for efficient navigation
|
||||
- Watch folder for automatic transcription of new files
|
||||
- Command-Line Interface for scripting and automation
|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
### macOS
|
||||
|
||||
Download the `.dmg` from the [SourceForge](https://sourceforge.net/projects/buzz-captions/files/).
|
||||
|
||||
### Windows
|
||||
|
||||
Get the installation files from the [SourceForge](https://sourceforge.net/projects/buzz-captions/files/).
|
||||
|
||||
App is not signed, you will get a warning when you install it. Select `More info` -> `Run anyway`.
|
||||
|
||||
### Linux
|
||||
|
||||
Buzz is available as a [Flatpak](https://flathub.org/apps/io.github.chidiwilliams.Buzz) or a [Snap](https://snapcraft.io/buzz).
|
||||
|
||||
To install flatpak, run:
|
||||
```shell
|
||||
flatpak install flathub io.github.chidiwilliams.Buzz
|
||||
```
|
||||
|
||||
[](https://flathub.org/en/apps/io.github.chidiwilliams.Buzz)
|
||||
|
||||
To install snap, run:
|
||||
```shell
|
||||
sudo apt-get install libportaudio2 libcanberra-gtk-module libcanberra-gtk3-module
|
||||
sudo snap install buzz
|
||||
```
|
||||
|
||||
[](https://snapcraft.io/buzz)
|
||||
|
||||
### PyPI
|
||||
**PyPI**:
|
||||
|
||||
Install [ffmpeg](https://www.ffmpeg.org/download.html)
|
||||
|
||||
Ensure you use Python 3.12 environment.
|
||||
|
||||
Install Buzz
|
||||
|
||||
```shell
|
||||
pip install buzz-captions
|
||||
python -m buzz
|
||||
```
|
||||
|
||||
**GPU support for PyPI**
|
||||
**macOS**:
|
||||
|
||||
To get GPU support for Nvidia GPUs on Windows with the PyPI-installed version, ensure CUDA support for [torch](https://pytorch.org/get-started/locally/)
|
||||
Install with [brew utility](https://brew.sh/)
|
||||
|
||||
```
|
||||
pip3 install -U torch==2.8.0+cu129 torchaudio==2.8.0+cu129 --index-url https://download.pytorch.org/whl/cu129
|
||||
pip3 install nvidia-cublas-cu12==12.9.1.4 nvidia-cuda-cupti-cu12==12.9.79 nvidia-cuda-runtime-cu12==12.9.79 --extra-index-url https://pypi.ngc.nvidia.com
|
||||
```shell
|
||||
brew install --cask buzz
|
||||
```
|
||||
|
||||
### Latest development version
|
||||
Or download the `.dmg` from the [releases page](https://github.com/chidiwilliams/buzz/releases/latest).
|
||||
|
||||
For info on how to get latest development version with latest features and bug fixes see [FAQ](https://chidiwilliams.github.io/buzz/docs/faq#9-where-can-i-get-latest-development-version).
|
||||
**Windows**:
|
||||
|
||||
### Support Buzz
|
||||
Download and run the `.exe` from the [releases page](https://github.com/chidiwilliams/buzz/releases/latest).
|
||||
|
||||
You can help the Buzz by starring 🌟 the repo and sharing it with your friends.
|
||||
App is not signed, you will get a warning when you install it. Select `More info` -> `Run anyway`.
|
||||
|
||||
### Screenshots
|
||||
|
||||
<div style="display: flex; flex-wrap: wrap;">
|
||||
<img alt="File import" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-1-import.png" style="max-width: 18%; margin-right: 1%;" />
|
||||
<img alt="Main screen" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-2-main_screen.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="Preferences" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-3-preferences.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="Model preferences" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-3.2-model-preferences.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="Transcript" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-4-transcript.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="Live recording" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-5-live_recording.png" style="max-width: 18%; margin-right: 1%; height:auto;" />
|
||||
<img alt="Resize" src="https://github.com/chidiwilliams/buzz/raw/main/share/screenshots/buzz-6-resize.png" style="max-width: 18%;" />
|
||||
</div>
|
||||
**Linux**:
|
||||
|
||||
```shell
|
||||
sudo apt-get install libportaudio2 libcanberra-gtk-module libcanberra-gtk3-module
|
||||
sudo snap install buzz
|
||||
sudo snap connect buzz:audio-record
|
||||
sudo snap connect buzz:password-manager-service
|
||||
sudo snap connect buzz:pulseaudio
|
||||
sudo snap connect buzz:removable-media
|
||||
```
|
||||
|
|
|
|||
9
build.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import subprocess


def build(setup_kwargs):
    """Build hook: generate the whisper.cpp Python bindings via the Makefile.

    Invokes ``make buzz/whisper_cpp.py`` as a child process. The
    ``setup_kwargs`` mapping is supplied by the packaging tool's build
    hook protocol and is not consumed here.
    """
    # Best-effort invocation: the exit status of `make` is deliberately
    # not checked, matching the original behavior of subprocess.call.
    subprocess.run(["make", "buzz/whisper_cpp.py"], check=False)


if __name__ == "__main__":
    # Allow running this hook directly as a script with no setup kwargs.
    build({})
|
||||
|
|
@ -1 +1 @@
|
|||
VERSION = "1.4.4"
|
||||
VERSION = "1.2.0"
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M6.75 1C6.33579 1 6 1.33579 6 1.75V3.50559C5.96824 3.53358 5.93715 3.56276 5.9068 3.59311L1.66416 7.83575C0.883107 8.6168 0.883107 9.88313 1.66416 10.6642L5.19969 14.1997C5.98074 14.9808 7.24707 14.9808 8.02812 14.1997L12.2708 9.95707C13.0518 9.17602 13.0518 7.90969 12.2708 7.12864L8.73522 3.59311C8.39027 3.24816 7.95066 3.05555 7.5 3.0153V1.75C7.5 1.33579 7.16421 1 6.75 1ZM6 5.62123V6.25C6 6.66421 6.33579 7 6.75 7C7.16421 7 7.5 6.66421 7.5 6.25V4.54033C7.56363 4.56467 7.62328 4.60249 7.67456 4.65377L11.2101 8.1893C11.2995 8.27875 11.348 8.39366 11.3555 8.51071H3.11052L6 5.62123ZM6.26035 13.1391L3.132 10.0107H10.0958L6.96746 13.1391C6.77219 13.3343 6.45561 13.3343 6.26035 13.1391Z" fill="#212121"/>
|
||||
<path d="M2 17.5V12.4143L3.5 13.9143V17.5C3.5 18.0523 3.94772 18.5 4.5 18.5H19.5C20.0523 18.5 20.5 18.0523 20.5 17.5V6.5C20.5 5.94771 20.0523 5.5 19.5 5.5H12.0563L10.5563 4H19.5C20.8807 4 22 5.11929 22 6.5V17.5C22 18.8807 20.8807 20 19.5 20H4.5C3.11929 20 2 18.8807 2 17.5Z" fill="#212121"/>
|
||||
<path d="M11 14.375C11 13.8816 11.1541 13.4027 11.3418 12.9938C11.5325 12.5784 11.7798 12.1881 12.0158 11.8595C12.2531 11.5289 12.4888 11.247 12.6647 11.0481C12.7502 10.9515 12.9062 10.7867 12.9642 10.7254L12.9697 10.7197C13.2626 10.4268 13.7374 10.4268 14.0303 10.7197L14.3353 11.0481C14.5112 11.247 14.7469 11.5289 14.9842 11.8595C15.2202 12.1881 15.4675 12.5784 15.6582 12.9938C15.8459 13.4027 16 13.8816 16 14.375C16 15.7654 14.9711 17 13.5 17C12.0289 17 11 15.7654 11 14.375ZM13.7658 12.7343C13.676 12.6092 13.5858 12.4916 13.5 12.3844C13.4142 12.4916 13.324 12.6092 13.2342 12.7343C13.0327 13.015 12.8425 13.32 12.7051 13.6195C12.5647 13.9253 12.5 14.1808 12.5 14.375C12.5 15.0663 12.9809 15.5 13.5 15.5C14.0191 15.5 14.5 15.0663 14.5 14.375C14.5 14.1808 14.4353 13.9253 14.2949 13.6195C14.1575 13.32 13.9673 13.015 13.7658 12.7343Z" fill="#212121"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2 KiB |
|
|
@ -1,5 +0,0 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M21.7092 2.29502C21.8041 2.3904 21.8757 2.50014 21.9241 2.61722C21.9727 2.73425 21.9996 2.8625 22 2.997L22 3V9C22 9.55228 21.5523 10 21 10C20.4477 10 20 9.55228 20 9V5.41421L14.7071 10.7071C14.3166 11.0976 13.6834 11.0976 13.2929 10.7071C12.9024 10.3166 12.9024 9.68342 13.2929 9.29289L18.5858 4H15C14.4477 4 14 3.55228 14 3C14 2.44772 14.4477 2 15 2H20.9998C21.2749 2 21.5242 2.11106 21.705 2.29078L21.7092 2.29502Z" fill="#000000"/>
|
||||
<path d="M10.7071 14.7071L5.41421 20H9C9.55228 20 10 20.4477 10 21C10 21.5523 9.55228 22 9 22H3.00069L2.997 22C2.74301 21.9992 2.48924 21.9023 2.29502 21.7092L2.29078 21.705C2.19595 21.6096 2.12432 21.4999 2.07588 21.3828C2.02699 21.2649 2 21.1356 2 21V15C2 14.4477 2.44772 14 3 14C3.55228 14 4 14.4477 4 15V18.5858L9.29289 13.2929C9.68342 12.9024 10.3166 12.9024 10.7071 13.2929C11.0976 13.6834 11.0976 14.3166 10.7071 14.7071Z" fill="#000000"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.1 KiB |
|
|
@ -1,2 +0,0 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg fill="#000000" width="800px" height="800px" viewBox="0 0 14 14" role="img" focusable="false" aria-hidden="true" xmlns="http://www.w3.org/2000/svg"><path d="M 7.5291661,11.795909 C 7.4168129,11.419456 7.3406864,10.225625 7.3406864,9.29222 c 0,-0.11438 -0.029767,-0.221667 -0.081573,-0.314893 0.051933,-0.115773 0.08132,-0.24358 0.08132,-0.378226 l 0,-1.709364 c 0,-0.511733 -0.416226,-0.927959 -0.9279585,-0.927959 l -0.8772919,0 C 5.527203,5.856265 5.52163,5.751005 5.518336,5.648406 5.514666,5.556066 5.513396,5.470313 5.513016,5.385826 5.511876,5.296776 5.5132694,5.224073 5.517196,5.160866 5.524666,5.024193 5.541009,4.891827 5.565076,4.773647 5.591043,4.646981 5.619669,4.564774 5.630689,4.535134 c 0.0019,-0.0052 0.0038,-0.01013 0.00557,-0.01533 0.00709,-0.02039 0.0133,-0.03559 0.017227,-0.04446 C 6.0127121,3.789698 5.750766,2.938499 5.0665137,2.5737 4.8642273,2.466034 4.6367344,2.409034 4.4084814,2.408147 4.1801018,2.409034 3.9526089,2.466037 3.7504492,2.5737 3.066197,2.938499 2.8042508,3.789698 3.1634768,4.475344 c 0.00393,0.0087 0.01026,0.02394 0.017227,0.04446 0.00177,0.0052 0.00367,0.01013 0.00557,0.01533 0.01102,0.02951 0.039647,0.111847 0.065613,0.238513 0.024067,0.11818 0.040533,0.250546 0.04788,0.387219 0.00393,0.06321 0.00532,0.135914 0.00418,0.22496 -5.066e-4,0.08449 -0.00165,0.17024 -0.00532,0.26258 -0.00329,0.102599 -0.00887,0.207859 -0.016847,0.313372 l -0.8772919,0 c -0.5117324,0 -0.9279584,0.416226 -0.9279584,0.927959 l 0,1.709364 c 0,0.134646 0.029387,0.262453 0.08132,0.378226 -0.051807,0.09323 -0.081573,0.200513 -0.081573,0.314893 0,0.933278 -0.076126,2.127236 -0.1884796,2.503689 C 1.0571435,11.985782 1.0131902,12.254315 1.0562568,12.453434 1.1748167,13 1.7477291,13 1.9359554,13 c 0.437506,0 1.226258,-0.07676 1.2595712,-0.08005 0.05092,-0.0051 0.1001932,-0.01596 0.1468065,-0.03179 0.049907,0.01241 0.1018398,0.01913 0.1546597,0.01925 l 0.9114918,0.0044 0.9114918,-0.0044 c 0.05282,-1.27e-4 0.1047532,-0.007 0.1546598,-0.01925 0.046613,0.01583 
0.095886,0.02673 0.1468064,0.03179 C 5.6547556,12.92315 6.4436346,13 6.8810138,13 c 0.1882264,0 0.7612654,0 0.8796986,-0.546566 0.043067,-0.199119 -7.6e-4,-0.467652 -0.2315463,-0.657525 z m -1.833117,0.502486 -0.3480794,-1.518478 -0.1741664,1.503658 -1.6846638,-7.6e-4 -0.3680927,-0.885399 0,0.900979 c 0,0 -1.7672504,0.173279 -1.3861111,0 0.3811394,-0.173154 0.3811394,-2.980082 0.3811394,-2.980082 l 2.2924095,0 2.2924095,0 c 0,0 0,2.806928 0.3811394,2.980082 0.381266,0.173279 -1.3859844,0 -1.3859844,0 z M 10.219055,1 7.3387864,1 5.8932688,5.377719 l 0.9449318,0 c 0.3536527,0 0.6674055,0.17138 0.8650052,0.434593 l 0.04864,-0.18392 0.9107318,-2.702555 0.2962729,-0.0016 0.9543051,2.889769 -2.2085564,0 C 7.839499,5.994632 7.9204389,6.217692 7.9204389,6.459878 l 0,1.257038 2.3962751,0 0.423193,1.60917 2.218563,0 L 10.219055,1 Z"/></svg>
|
||||
|
Before Width: | Height: | Size: 2.9 KiB |
|
|
@ -1,7 +0,0 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="-0.5 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8.93994 9.39998V5.48999C8.93994 5.20999 9.15994 4.98999 9.43994 4.98999H20.9999C21.2799 4.98999 21.4999 5.20999 21.4999 5.48999V13.09C21.4999 13.37 21.2799 13.59 20.9999 13.59L17.0599 13.6" stroke="#0F0F0F" stroke-miterlimit="10" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M17.7301 8.72998L16.4301 10.03" stroke="#0F0F0F" stroke-miterlimit="10" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M3 11.4H14.56C14.84 11.4 15.06 11.62 15.06 11.9V19.51C15.06 19.79 14.84 20.01 14.56 20.01H3C2.72 20.01 2.5 19.79 2.5 19.51V11.9C2.5 11.63 2.72 11.4 3 11.4Z" stroke="#0F0F0F" stroke-miterlimit="10" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M19.32 10.03V7.64001C19.32 7.36001 19.1 7.14001 18.82 7.14001H16.42" stroke="#0F0F0F" stroke-miterlimit="10" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.1 KiB |
|
|
@ -1,14 +0,0 @@
|
|||
<?xml version="1.0" encoding="iso-8859-1"?>
|
||||
<svg height="800px" width="800px" version="1.1" id="Capa_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
|
||||
viewBox="0 0 493.347 493.347" xml:space="preserve">
|
||||
<g>
|
||||
<path style="fill:#010002;" d="M191.936,385.946c-14.452,0-29.029-1.36-43.319-4.04l-5.299-0.996l-66.745,37.15v-63.207
|
||||
l-6.629-4.427C25.496,320.716,0,277.045,0,230.617c0-85.648,86.102-155.33,191.936-155.33c17.077,0,33.623,1.838,49.394,5.239
|
||||
c-50.486,27.298-84.008,74.801-84.008,128.765c0,72.969,61.25,134.147,142.942,149.464
|
||||
C269.41,375.892,232.099,385.946,191.936,385.946z"/>
|
||||
<path style="fill:#010002;" d="M437.777,304.278l-6.629,4.427v48.075l-50.933-28.343l-0.125,0.024l-5.167,0.967
|
||||
c-11.444,2.142-23.104,3.228-34.673,3.228c-1.241,0-2.47-0.054-3.705-0.078c-82.707-1.599-149.387-56.268-149.387-123.287
|
||||
c0-52.109,40.324-96.741,97.129-114.791c14.47-4.594,30.001-7.471,46.219-8.3c3.228-0.167,6.468-0.274,9.75-0.274
|
||||
c84.413,0,153.092,55.343,153.092,123.365C493.347,246.053,473.089,280.679,437.777,304.278z"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1 KiB |
|
|
@ -1 +0,0 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" height="48" viewBox="0 -960 960 960" width="48"><path d="M160-200v-60h640v60H160Zm320-136L280-536l42-42 128 128v-310h60v310l128-128 42 42-200 200Z" transform="rotate(180 480 -480)"/></svg>
|
||||
|
Before Width: | Height: | Size: 229 B |
|
|
@ -1,19 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 -0.5 21 21" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
|
||||
<title>url [#1423]</title>
|
||||
<desc>Created with Sketch.</desc>
|
||||
<defs>
|
||||
|
||||
</defs>
|
||||
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="Dribbble-Light-Preview" transform="translate(-339.000000, -600.000000)" fill="#000000">
|
||||
<g id="icons" transform="translate(56.000000, 160.000000)">
|
||||
<path d="M286.388001,443.226668 C288.054626,441.639407 290.765027,441.639407 292.431651,443.226668 L293.942296,444.665378 L295.452942,443.226668 L293.942296,441.787958 C291.439155,439.404014 287.380498,439.404014 284.877356,441.787958 C282.374215,444.171902 282.374215,448.03729 284.877356,450.421235 L286.388001,451.859945 L287.898647,450.421235 L286.388001,448.982525 C284.721377,447.395264 284.721377,444.813929 286.388001,443.226668 L286.388001,443.226668 Z M302.122644,449.578765 L300.611999,448.139038 L299.101353,449.578765 L300.611999,451.017475 C302.277554,452.603719 302.277554,455.186071 300.611999,456.773332 C298.945374,458.359576 296.233905,458.359576 294.568349,456.773332 L293.057704,455.333605 L291.54599,456.773332 L293.057704,458.212042 C295.560845,460.595986 299.619502,460.595986 302.122644,458.212042 C304.625785,455.828098 304.625785,451.96271 302.122644,449.578765 L302.122644,449.578765 Z M288.653969,443.946023 L299.856676,454.61425 L298.344962,456.053977 L287.143324,445.384733 L288.653969,443.946023 Z" id="url-[#1423]">
|
||||
|
||||
</path>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.7 KiB |
41
buzz/buzz.py
|
|
@ -4,32 +4,18 @@ import multiprocessing
|
|||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
# Set up CUDA library paths before any torch imports
|
||||
# This must happen before platformdirs or any other imports that might indirectly load torch
|
||||
import buzz.cuda_setup # noqa: F401
|
||||
|
||||
from platformdirs import user_log_dir, user_cache_dir, user_data_dir
|
||||
|
||||
# Will download all Huggingface data to the app cache directory
|
||||
os.environ.setdefault("HF_HOME", user_cache_dir("Buzz"))
|
||||
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
|
||||
# Check for segfaults if not running in frozen mode
|
||||
# Note: On Windows, faulthandler can print "Windows fatal exception" messages
|
||||
# for non-fatal RPC errors (0x800706be) during multiprocessing operations.
|
||||
# These are usually harmless but noisy, so we disable faulthandler on Windows.
|
||||
if getattr(sys, "frozen", False) is False and platform.system() != "Windows":
|
||||
if getattr(sys, "frozen", False) is False:
|
||||
faulthandler.enable()
|
||||
|
||||
# Sets stdout/stderr to no-op TextIO when None (run as Windows GUI with --noconsole).
|
||||
# stdout fix: torch.hub uses sys.stdout.write() for download progress and crashes if None.
|
||||
# stderr fix: Resolves https://github.com/chidiwilliams/buzz/issues/221
|
||||
if sys.stdout is None:
|
||||
sys.stdout = TextIO()
|
||||
# Sets stderr to no-op TextIO when None (run as Windows GUI).
|
||||
# Resolves https://github.com/chidiwilliams/buzz/issues/221
|
||||
if sys.stderr is None:
|
||||
sys.stderr = TextIO()
|
||||
|
||||
|
|
@ -40,14 +26,7 @@ os.environ["PATH"] += os.pathsep + APP_BASE_DIR
|
|||
# Add the app directory to the DLL list: https://stackoverflow.com/a/64303856
|
||||
if platform.system() == "Windows":
|
||||
os.add_dll_directory(APP_BASE_DIR)
|
||||
|
||||
dll_backup_dir = os.path.join(APP_BASE_DIR, "dll_backup")
|
||||
if os.path.isdir(dll_backup_dir):
|
||||
os.add_dll_directory(dll_backup_dir)
|
||||
|
||||
onnx_dll_dir = os.path.join(APP_BASE_DIR, "onnxruntime", "capi")
|
||||
if os.path.isdir(onnx_dll_dir):
|
||||
os.add_dll_directory(onnx_dll_dir)
|
||||
os.add_dll_directory(os.path.join(APP_BASE_DIR, "dll_backup"))
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -70,18 +49,6 @@ def main():
|
|||
format=log_format,
|
||||
)
|
||||
|
||||
# Silence noisy third-party library loggers
|
||||
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
||||
logging.getLogger("graphviz").setLevel(logging.WARNING)
|
||||
logging.getLogger("nemo_logger").setLevel(logging.ERROR)
|
||||
logging.getLogger("nemo_logging").setLevel(logging.ERROR)
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
logging.getLogger("torio._extension.utils").setLevel(logging.WARNING)
|
||||
logging.getLogger("export_config_manager").setLevel(logging.WARNING)
|
||||
logging.getLogger("training_telemetry_provider").setLevel(logging.ERROR)
|
||||
logging.getLogger("default_recorder").setLevel(logging.WARNING)
|
||||
logging.getLogger("config").setLevel(logging.WARNING)
|
||||
|
||||
if getattr(sys, "frozen", False) is False:
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setLevel(logging.DEBUG)
|
||||
|
|
|
|||
|
|
@ -100,10 +100,7 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
["p", "prompt"], "Initial prompt.", "prompt", ""
|
||||
)
|
||||
word_timestamp_option = QCommandLineOption(
|
||||
["w", "word-timestamps"], "Generate word-level timestamps."
|
||||
)
|
||||
extract_speech_option = QCommandLineOption(
|
||||
["e", "extract-speech"], "Extract speech from audio before transcribing."
|
||||
["wt", "word-timestamps"], "Generate word-level timestamps."
|
||||
)
|
||||
open_ai_access_token_option = QCommandLineOption(
|
||||
"openai-token",
|
||||
|
|
@ -127,7 +124,6 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
language_option,
|
||||
initial_prompt_option,
|
||||
word_timestamp_option,
|
||||
extract_speech_option,
|
||||
open_ai_access_token_option,
|
||||
output_directory_option,
|
||||
srt_option,
|
||||
|
|
@ -182,7 +178,6 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
initial_prompt = parser.value(initial_prompt_option)
|
||||
|
||||
word_timestamps = parser.isSet(word_timestamp_option)
|
||||
extract_speech = parser.isSet(extract_speech_option)
|
||||
|
||||
output_formats: typing.Set[OutputFormat] = set()
|
||||
if parser.isSet(srt_option):
|
||||
|
|
@ -210,7 +205,6 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
language=language,
|
||||
initial_prompt=initial_prompt,
|
||||
word_level_timings=word_timestamps,
|
||||
extract_speech=extract_speech,
|
||||
openai_access_token=openai_access_token,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,130 +0,0 @@
|
|||
"""
|
||||
CUDA library path setup for nvidia packages installed via pip.
|
||||
|
||||
This module must be imported BEFORE any torch or CUDA-dependent libraries are imported.
|
||||
It handles locating and loading CUDA libraries (cuDNN, cuBLAS, etc.) from the nvidia
|
||||
pip packages.
|
||||
|
||||
On Windows: Uses os.add_dll_directory() to add library paths
|
||||
On Linux: Uses ctypes to preload libraries (LD_LIBRARY_PATH is read at process start)
|
||||
On macOS: No action needed (CUDA not supported)
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_nvidia_package_lib_dirs() -> list[Path]:
|
||||
"""Find all nvidia package library directories in site-packages."""
|
||||
lib_dirs = []
|
||||
|
||||
# Find site-packages directories
|
||||
site_packages_dirs = []
|
||||
for path in sys.path:
|
||||
if "site-packages" in path:
|
||||
site_packages_dirs.append(Path(path))
|
||||
|
||||
# Also check relative to the current module for frozen apps
|
||||
if getattr(sys, "frozen", False):
|
||||
# For frozen apps, check the _internal directory
|
||||
frozen_lib_dir = Path(sys._MEIPASS) if hasattr(sys, "_MEIPASS") else Path(sys.executable).parent
|
||||
nvidia_dir = frozen_lib_dir / "nvidia"
|
||||
if nvidia_dir.exists():
|
||||
for pkg_dir in nvidia_dir.iterdir():
|
||||
if pkg_dir.is_dir():
|
||||
lib_subdir = pkg_dir / "lib"
|
||||
if lib_subdir.exists():
|
||||
lib_dirs.append(lib_subdir)
|
||||
# Some packages have bin directory on Windows
|
||||
bin_subdir = pkg_dir / "bin"
|
||||
if bin_subdir.exists():
|
||||
lib_dirs.append(bin_subdir)
|
||||
|
||||
# Check each site-packages for nvidia packages
|
||||
for sp_dir in site_packages_dirs:
|
||||
nvidia_dir = sp_dir / "nvidia"
|
||||
if nvidia_dir.exists():
|
||||
for pkg_dir in nvidia_dir.iterdir():
|
||||
if pkg_dir.is_dir():
|
||||
lib_subdir = pkg_dir / "lib"
|
||||
if lib_subdir.exists():
|
||||
lib_dirs.append(lib_subdir)
|
||||
# Some packages have bin directory on Windows
|
||||
bin_subdir = pkg_dir / "bin"
|
||||
if bin_subdir.exists():
|
||||
lib_dirs.append(bin_subdir)
|
||||
|
||||
return lib_dirs
|
||||
|
||||
|
||||
def _setup_windows_dll_directories():
|
||||
"""Add nvidia library directories to Windows DLL search path."""
|
||||
lib_dirs = _get_nvidia_package_lib_dirs()
|
||||
for lib_dir in lib_dirs:
|
||||
try:
|
||||
os.add_dll_directory(str(lib_dir))
|
||||
except (OSError, AttributeError) as e:
|
||||
pass
|
||||
|
||||
|
||||
def _preload_linux_libraries():
|
||||
"""Preload CUDA libraries on Linux using ctypes.
|
||||
|
||||
On Linux, LD_LIBRARY_PATH is only read at process start, so we need to
|
||||
manually load the libraries using ctypes before torch tries to load them.
|
||||
"""
|
||||
lib_dirs = _get_nvidia_package_lib_dirs()
|
||||
|
||||
# Libraries to skip - NVBLAS requires special configuration and causes issues
|
||||
skip_patterns = ["libnvblas"]
|
||||
|
||||
loaded_libs = set()
|
||||
|
||||
for lib_dir in lib_dirs:
|
||||
if not lib_dir.exists():
|
||||
continue
|
||||
|
||||
# Find all .so files in the directory
|
||||
for lib_file in sorted(lib_dir.glob("*.so*")):
|
||||
if lib_file.name in loaded_libs:
|
||||
continue
|
||||
if lib_file.is_symlink() and not lib_file.exists():
|
||||
continue
|
||||
|
||||
# Skip problematic libraries
|
||||
if any(pattern in lib_file.name for pattern in skip_patterns):
|
||||
continue
|
||||
|
||||
try:
|
||||
# Use RTLD_GLOBAL so symbols are available to other libraries
|
||||
ctypes.CDLL(str(lib_file), mode=ctypes.RTLD_GLOBAL)
|
||||
loaded_libs.add(lib_file.name)
|
||||
except OSError as e:
|
||||
# Some libraries may have missing dependencies, that's ok
|
||||
pass
|
||||
|
||||
|
||||
def setup_cuda_libraries():
|
||||
"""Set up CUDA library paths for the current platform.
|
||||
|
||||
This function should be called as early as possible, before any torch
|
||||
or CUDA-dependent libraries are imported.
|
||||
"""
|
||||
system = platform.system()
|
||||
|
||||
if system == "Windows":
|
||||
_setup_windows_dll_directories()
|
||||
elif system == "Linux":
|
||||
_preload_linux_libraries()
|
||||
# macOS doesn't have CUDA support, so nothing to do
|
||||
|
||||
|
||||
# Auto-run setup when this module is imported
|
||||
setup_cuda_libraries()
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -32,11 +31,7 @@ class TranscriptionDAO(DAO[Transcription]):
|
|||
time_queued,
|
||||
url,
|
||||
whisper_model_size,
|
||||
hugging_face_model_id,
|
||||
word_level_timings,
|
||||
extract_speech,
|
||||
name,
|
||||
notes
|
||||
hugging_face_model_id
|
||||
) VALUES (
|
||||
:id,
|
||||
:export_formats,
|
||||
|
|
@ -50,13 +45,9 @@ class TranscriptionDAO(DAO[Transcription]):
|
|||
:time_queued,
|
||||
:url,
|
||||
:whisper_model_size,
|
||||
:hugging_face_model_id,
|
||||
:word_level_timings,
|
||||
:extract_speech,
|
||||
:name,
|
||||
:notes
|
||||
:hugging_face_model_id
|
||||
)
|
||||
"""
|
||||
"""
|
||||
)
|
||||
query.bindValue(":id", str(task.uid))
|
||||
query.bindValue(
|
||||
|
|
@ -91,84 +82,9 @@ class TranscriptionDAO(DAO[Transcription]):
|
|||
if task.transcription_options.model.hugging_face_model_id
|
||||
else None,
|
||||
)
|
||||
query.bindValue(
|
||||
":word_level_timings",
|
||||
task.transcription_options.word_level_timings
|
||||
)
|
||||
query.bindValue(
|
||||
":extract_speech",
|
||||
task.transcription_options.extract_speech
|
||||
)
|
||||
query.bindValue(":name", None) # name is not available in FileTranscriptionTask
|
||||
query.bindValue(":notes", None) # notes is not available in FileTranscriptionTask
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
|
||||
def copy_transcription(self, id: UUID) -> UUID:
|
||||
query = self._create_query()
|
||||
query.prepare("SELECT * FROM transcription WHERE id = :id")
|
||||
query.bindValue(":id", str(id))
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
if not query.next():
|
||||
raise Exception("Transcription not found")
|
||||
|
||||
transcription_data = {field.name: query.value(field.name) for field in
|
||||
self.entity.__dataclass_fields__.values()}
|
||||
|
||||
new_id = uuid.uuid4()
|
||||
transcription_data["id"] = str(new_id)
|
||||
transcription_data["time_queued"] = datetime.now().isoformat()
|
||||
transcription_data["status"] = FileTranscriptionTask.Status.QUEUED.value
|
||||
|
||||
query.prepare(
|
||||
"""
|
||||
INSERT INTO transcription (
|
||||
id,
|
||||
export_formats,
|
||||
file,
|
||||
output_folder,
|
||||
language,
|
||||
model_type,
|
||||
source,
|
||||
status,
|
||||
task,
|
||||
time_queued,
|
||||
url,
|
||||
whisper_model_size,
|
||||
hugging_face_model_id,
|
||||
word_level_timings,
|
||||
extract_speech,
|
||||
name,
|
||||
notes
|
||||
) VALUES (
|
||||
:id,
|
||||
:export_formats,
|
||||
:file,
|
||||
:output_folder,
|
||||
:language,
|
||||
:model_type,
|
||||
:source,
|
||||
:status,
|
||||
:task,
|
||||
:time_queued,
|
||||
:url,
|
||||
:whisper_model_size,
|
||||
:hugging_face_model_id,
|
||||
:word_level_timings,
|
||||
:extract_speech,
|
||||
:name,
|
||||
:notes
|
||||
)
|
||||
"""
|
||||
)
|
||||
for key, value in transcription_data.items():
|
||||
query.bindValue(f":{key}", value)
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
|
||||
return new_id
|
||||
|
||||
def update_transcription_as_started(self, id: UUID):
|
||||
query = self._create_query()
|
||||
query.prepare(
|
||||
|
|
@ -249,72 +165,3 @@ class TranscriptionDAO(DAO[Transcription]):
|
|||
query.bindValue(":time_ended", datetime.now().isoformat())
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
|
||||
def update_transcription_file_and_name(self, id: UUID, file_path: str, name: str | None = None):
|
||||
query = self._create_query()
|
||||
query.prepare(
|
||||
"""
|
||||
UPDATE transcription
|
||||
SET file = :file, name = COALESCE(:name, name)
|
||||
WHERE id = :id
|
||||
"""
|
||||
)
|
||||
|
||||
query.bindValue(":id", str(id))
|
||||
query.bindValue(":file", file_path)
|
||||
query.bindValue(":name", name)
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
|
||||
def update_transcription_name(self, id: UUID, name: str):
|
||||
query = self._create_query()
|
||||
query.prepare(
|
||||
"""
|
||||
UPDATE transcription
|
||||
SET name = :name
|
||||
WHERE id = :id
|
||||
"""
|
||||
)
|
||||
|
||||
query.bindValue(":id", str(id))
|
||||
query.bindValue(":name", name)
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
if query.numRowsAffected() == 0:
|
||||
raise Exception("Transcription not found")
|
||||
|
||||
def update_transcription_notes(self, id: UUID, notes: str):
|
||||
query = self._create_query()
|
||||
query.prepare(
|
||||
"""
|
||||
UPDATE transcription
|
||||
SET notes = :notes
|
||||
WHERE id = :id
|
||||
"""
|
||||
)
|
||||
|
||||
query.bindValue(":id", str(id))
|
||||
query.bindValue(":notes", notes)
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
if query.numRowsAffected() == 0:
|
||||
raise Exception("Transcription not found")
|
||||
|
||||
def reset_transcription_for_restart(self, id: UUID):
|
||||
"""Reset a transcription to queued status for restart"""
|
||||
query = self._create_query()
|
||||
query.prepare(
|
||||
"""
|
||||
UPDATE transcription
|
||||
SET status = :status, progress = :progress, time_started = NULL, time_ended = NULL, error_message = NULL
|
||||
WHERE id = :id
|
||||
"""
|
||||
)
|
||||
|
||||
query.bindValue(":id", str(id))
|
||||
query.bindValue(":status", FileTranscriptionTask.Status.QUEUED.value)
|
||||
query.bindValue(":progress", 0.0)
|
||||
if not query.exec():
|
||||
raise Exception(query.lastError().text())
|
||||
if query.numRowsAffected() == 0:
|
||||
raise Exception("Transcription not found")
|
||||
|
|
|
|||
|
|
@ -25,14 +25,11 @@ def setup_test_db() -> QSqlDatabase:
|
|||
|
||||
def _setup_db(path: str) -> QSqlDatabase:
|
||||
# Run migrations
|
||||
db = sqlite3.connect(path, isolation_level=None, timeout=10.0)
|
||||
try:
|
||||
run_sqlite_migrations(db)
|
||||
copy_transcriptions_from_json_to_sqlite(db)
|
||||
mark_in_progress_and_queued_transcriptions_as_canceled(db)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
db = sqlite3.connect(path)
|
||||
run_sqlite_migrations(db)
|
||||
copy_transcriptions_from_json_to_sqlite(db)
|
||||
mark_in_progress_and_queued_transcriptions_as_canceled(db)
|
||||
db.close()
|
||||
|
||||
db = QSqlDatabase.addDatabase("QSQLITE")
|
||||
db.setDatabaseName(path)
|
||||
|
|
@ -41,12 +38,3 @@ def _setup_db(path: str) -> QSqlDatabase:
|
|||
db.exec('PRAGMA foreign_keys = ON')
|
||||
logging.debug("Database connection opened: %s", db.databaseName())
|
||||
return db
|
||||
|
||||
|
||||
def close_app_db():
|
||||
db = QSqlDatabase.database()
|
||||
if not db.isValid():
|
||||
return
|
||||
|
||||
if db.isOpen():
|
||||
db.close()
|
||||
|
|
@ -16,8 +16,6 @@ class Transcription(Entity):
|
|||
model_type: str = ModelType.WHISPER.value
|
||||
whisper_model_size: str | None = None
|
||||
hugging_face_model_id: str | None = None
|
||||
word_level_timings: str | None = None
|
||||
extract_speech: str | None = None
|
||||
language: str | None = None
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
error_message: str | None = None
|
||||
|
|
@ -30,8 +28,6 @@ class Transcription(Entity):
|
|||
output_folder: str | None = None
|
||||
source: str | None = None
|
||||
url: str | None = None
|
||||
name: str | None = None
|
||||
notes: str | None = None
|
||||
|
||||
@property
|
||||
def id_as_uuid(self):
|
||||
|
|
|
|||
|
|
@ -69,8 +69,7 @@ class DBMigrator:
|
|||
msg_argv += (args,)
|
||||
else:
|
||||
args = []
|
||||
# Uncomment this to get debugging information
|
||||
# logging.info(msg_tmpl, *msg_argv)
|
||||
logging.info(msg_tmpl, *msg_argv)
|
||||
self.db.execute(sql, args)
|
||||
self.n_changes += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -19,9 +19,6 @@ class TranscriptionService:
|
|||
def create_transcription(self, task):
|
||||
self.transcription_dao.create_transcription(task)
|
||||
|
||||
def copy_transcription(self, id: UUID) -> UUID:
|
||||
return self.transcription_dao.copy_transcription(id)
|
||||
|
||||
def update_transcription_as_started(self, id: UUID):
|
||||
self.transcription_dao.update_transcription_as_started(id)
|
||||
|
||||
|
|
@ -47,18 +44,6 @@ class TranscriptionService:
|
|||
)
|
||||
)
|
||||
|
||||
def update_transcription_file_and_name(self, id: UUID, file_path: str, name: str | None = None):
|
||||
self.transcription_dao.update_transcription_file_and_name(id, file_path, name)
|
||||
|
||||
def update_transcription_name(self, id: UUID, name: str):
|
||||
self.transcription_dao.update_transcription_name(id, name)
|
||||
|
||||
def update_transcription_notes(self, id: UUID, notes: str):
|
||||
self.transcription_dao.update_transcription_notes(id, notes)
|
||||
|
||||
def reset_transcription_for_restart(self, id: UUID):
|
||||
self.transcription_dao.reset_transcription_for_restart(id)
|
||||
|
||||
def replace_transcription_segments(self, id: UUID, segments: List[Segment]):
|
||||
self.transcription_segment_dao.delete_segments(id)
|
||||
for segment in segments:
|
||||
|
|
|
|||
|
|
@ -1,65 +1,18 @@
|
|||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import queue
|
||||
import ssl
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List, Set
|
||||
from uuid import UUID
|
||||
|
||||
# Fix SSL certificate verification for bundled applications (macOS, Windows)
|
||||
# This must be done before importing demucs which uses torch.hub with urllib
|
||||
try:
|
||||
import certifi
|
||||
os.environ.setdefault('REQUESTS_CA_BUNDLE', certifi.where())
|
||||
os.environ.setdefault('SSL_CERT_FILE', certifi.where())
|
||||
os.environ.setdefault('SSL_CERT_DIR', os.path.dirname(certifi.where()))
|
||||
# Also update the default SSL context for urllib
|
||||
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
|
||||
except ImportError:
|
||||
pass
|
||||
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
|
||||
|
||||
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot, Qt
|
||||
|
||||
# Patch subprocess for demucs to prevent console windows on Windows
|
||||
if sys.platform == "win32":
|
||||
import subprocess
|
||||
_original_run = subprocess.run
|
||||
_original_check_output = subprocess.check_output
|
||||
|
||||
def _patched_run(*args, **kwargs):
|
||||
if 'startupinfo' not in kwargs:
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
kwargs['startupinfo'] = si
|
||||
if 'creationflags' not in kwargs:
|
||||
kwargs['creationflags'] = subprocess.CREATE_NO_WINDOW
|
||||
return _original_run(*args, **kwargs)
|
||||
|
||||
def _patched_check_output(*args, **kwargs):
|
||||
if 'startupinfo' not in kwargs:
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
kwargs['startupinfo'] = si
|
||||
if 'creationflags' not in kwargs:
|
||||
kwargs['creationflags'] = subprocess.CREATE_NO_WINDOW
|
||||
return _original_check_output(*args, **kwargs)
|
||||
|
||||
subprocess.run = _patched_run
|
||||
subprocess.check_output = _patched_check_output
|
||||
|
||||
from demucs import api as demucsApi
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber
|
||||
from buzz.transcriber.openai_whisper_api_file_transcriber import (
|
||||
OpenAIWhisperAPIFileTranscriber,
|
||||
)
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
|
||||
from buzz.transcriber.whisper_cpp_file_transcriber import WhisperCppFileTranscriber
|
||||
from buzz.transcriber.whisper_file_transcriber import WhisperFileTranscriber
|
||||
|
||||
|
||||
|
|
@ -76,38 +29,22 @@ class FileTranscriberQueueWorker(QObject):
|
|||
task_error = pyqtSignal(FileTranscriptionTask, str)
|
||||
|
||||
completed = pyqtSignal()
|
||||
trigger_run = pyqtSignal()
|
||||
|
||||
def __init__(self, parent: Optional[QObject] = None):
|
||||
super().__init__(parent)
|
||||
self.tasks_queue = queue.Queue()
|
||||
self.canceled_tasks: Set[UUID] = set()
|
||||
self.current_transcriber = None
|
||||
self.speech_path = None
|
||||
self.is_running = False
|
||||
# Use QueuedConnection to ensure run() is called in the correct thread context
|
||||
# and doesn't block signal handlers
|
||||
self.trigger_run.connect(self.run, Qt.ConnectionType.QueuedConnection)
|
||||
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
logging.debug("Waiting for next transcription task")
|
||||
|
||||
# Clean up of previous run.
|
||||
if self.current_transcriber is not None:
|
||||
self.current_transcriber.stop()
|
||||
self.current_transcriber = None
|
||||
|
||||
# Get next non-canceled task from queue
|
||||
while True:
|
||||
self.current_task: Optional[FileTranscriptionTask] = self.tasks_queue.get()
|
||||
|
||||
# Stop listening when a "None" task is received
|
||||
if self.current_task is None:
|
||||
self.is_running = False
|
||||
self.completed.emit()
|
||||
return
|
||||
|
||||
|
|
@ -116,66 +53,17 @@ class FileTranscriberQueueWorker(QObject):
|
|||
|
||||
break
|
||||
|
||||
# Set is_running AFTER we have a valid task to process
|
||||
self.is_running = True
|
||||
|
||||
if self.current_task.transcription_options.extract_speech:
|
||||
logging.debug("Will extract speech")
|
||||
|
||||
def separator_progress_callback(progress):
|
||||
self.task_progress.emit(self.current_task, int(progress["segment_offset"] * 100) / int(progress["audio_length"] * 100))
|
||||
|
||||
separator = None
|
||||
separated = None
|
||||
try:
|
||||
# Force CPU if specified, otherwise use CUDA if available
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false").lower() == "true"
|
||||
if force_cpu:
|
||||
device = "cpu"
|
||||
else:
|
||||
import torch
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
separator = demucsApi.Separator(
|
||||
device=device,
|
||||
progress=True,
|
||||
callback=separator_progress_callback,
|
||||
)
|
||||
_origin, separated = separator.separate_audio_file(Path(self.current_task.file_path))
|
||||
|
||||
task_file_path = Path(self.current_task.file_path)
|
||||
self.speech_path = task_file_path.with_name(f"{task_file_path.stem}_speech.mp3")
|
||||
demucsApi.save_audio(separated["vocals"], self.speech_path, separator.samplerate)
|
||||
|
||||
self.current_task.file_path = str(self.speech_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during speech extraction: {e}", exc_info=True)
|
||||
self.task_error.emit(
|
||||
self.current_task,
|
||||
_("Speech extraction failed! Check your internet connection — a model may need to be downloaded."),
|
||||
)
|
||||
self.is_running = False
|
||||
return
|
||||
finally:
|
||||
# Release memory used by speech extractor
|
||||
del separator, separated
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logging.debug("Starting next transcription task")
|
||||
self.task_progress.emit(self.current_task, 0)
|
||||
|
||||
model_type = self.current_task.transcription_options.model.model_type
|
||||
if model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
if model_type == ModelType.WHISPER_CPP:
|
||||
self.current_transcriber = WhisperCppFileTranscriber(task=self.current_task)
|
||||
elif model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(
|
||||
task=self.current_task
|
||||
)
|
||||
elif (
|
||||
model_type == ModelType.WHISPER_CPP
|
||||
or model_type == ModelType.HUGGING_FACE
|
||||
model_type == ModelType.HUGGING_FACE
|
||||
or model_type == ModelType.WHISPER
|
||||
or model_type == ModelType.FASTER_WHISPER
|
||||
):
|
||||
|
|
@ -206,53 +94,29 @@ class FileTranscriberQueueWorker(QObject):
|
|||
self.current_transcriber.completed.connect(self.on_task_completed)
|
||||
|
||||
# Wait for next item on the queue
|
||||
self.current_transcriber.error.connect(lambda: self._on_task_finished())
|
||||
self.current_transcriber.completed.connect(lambda: self._on_task_finished())
|
||||
self.current_transcriber.error.connect(self.run)
|
||||
self.current_transcriber.completed.connect(self.run)
|
||||
|
||||
self.task_started.emit(self.current_task)
|
||||
self.current_transcriber_thread.start()
|
||||
|
||||
def _on_task_finished(self):
|
||||
"""Called when a task completes or errors, resets state and triggers next run"""
|
||||
self.is_running = False
|
||||
# Use signal to avoid blocking in signal handler context
|
||||
self.trigger_run.emit()
|
||||
|
||||
def add_task(self, task: FileTranscriptionTask):
|
||||
# Remove from canceled tasks if it was previously canceled (for restart functionality)
|
||||
if task.uid in self.canceled_tasks:
|
||||
self.canceled_tasks.remove(task.uid)
|
||||
|
||||
self.tasks_queue.put(task)
|
||||
# If the worker is not currently running, trigger it to start processing
|
||||
# Use signal to avoid blocking the main thread
|
||||
if not self.is_running:
|
||||
self.trigger_run.emit()
|
||||
|
||||
def cancel_task(self, task_id: UUID):
|
||||
self.canceled_tasks.add(task_id)
|
||||
|
||||
if self.current_task is not None and self.current_task.uid == task_id:
|
||||
if self.current_task.uid == task_id:
|
||||
if self.current_transcriber is not None:
|
||||
self.current_transcriber.stop()
|
||||
|
||||
if self.current_transcriber_thread is not None:
|
||||
if not self.current_transcriber_thread.wait(5000):
|
||||
logging.warning("Transcriber thread did not terminate gracefully")
|
||||
self.current_transcriber_thread.terminate()
|
||||
|
||||
def on_task_error(self, error: str):
|
||||
if (
|
||||
self.current_task is not None
|
||||
and self.current_task.uid not in self.canceled_tasks
|
||||
):
|
||||
# Check if the error indicates cancellation
|
||||
if "canceled" in error.lower() or "cancelled" in error.lower():
|
||||
self.current_task.status = FileTranscriptionTask.Status.CANCELED
|
||||
self.current_task.error = error
|
||||
else:
|
||||
self.current_task.status = FileTranscriptionTask.Status.FAILED
|
||||
self.current_task.error = error
|
||||
self.current_task.status = FileTranscriptionTask.Status.FAILED
|
||||
self.current_task.error = error
|
||||
self.task_error.emit(self.current_task, error)
|
||||
|
||||
@pyqtSlot(tuple)
|
||||
|
|
@ -269,13 +133,6 @@ class FileTranscriberQueueWorker(QObject):
|
|||
if self.current_task is not None:
|
||||
self.task_completed.emit(self.current_task, segments)
|
||||
|
||||
if self.speech_path is not None:
|
||||
try:
|
||||
Path(self.speech_path).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
self.speech_path = None
|
||||
|
||||
def stop(self):
|
||||
self.tasks_queue.put(None)
|
||||
if self.current_transcriber is not None:
|
||||
|
|
|
|||
|
|
@ -5,19 +5,19 @@ import gettext
|
|||
from PyQt6.QtCore import QLocale
|
||||
|
||||
from buzz.assets import get_path
|
||||
from buzz.settings.settings import APP_NAME, Settings
|
||||
from buzz.settings.settings import APP_NAME
|
||||
|
||||
locale_dir = get_path("locale")
|
||||
gettext.bindtextdomain("buzz", locale_dir)
|
||||
|
||||
settings = Settings()
|
||||
custom_locale = os.getenv("BUZZ_LOCALE")
|
||||
|
||||
languages = [
|
||||
settings.value(settings.Key.UI_LOCALE, QLocale().name())
|
||||
]
|
||||
languages = [custom_locale] if custom_locale else QLocale().uiLanguages()
|
||||
|
||||
logging.debug(f"UI locales {languages}")
|
||||
|
||||
translate = gettext.translation(
|
||||
APP_NAME.lower(), locale_dir, languages=languages, fallback=True
|
||||
)
|
||||
|
||||
_ = translate.gettext
|
||||
_ = translate.gettext
|
||||
|
|
|
|||
|
|
@ -7,23 +7,9 @@ import threading
|
|||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import ssl
|
||||
import tempfile
|
||||
import warnings
|
||||
import platform
|
||||
|
||||
# Fix SSL certificate verification for bundled applications (macOS, Windows).
|
||||
# This must be done before importing libraries that make HTTPS requests.
|
||||
try:
|
||||
import certifi
|
||||
_certifi_ca_bundle = certifi.where()
|
||||
os.environ.setdefault("REQUESTS_CA_BUNDLE", _certifi_ca_bundle)
|
||||
os.environ.setdefault("SSL_CERT_FILE", _certifi_ca_bundle)
|
||||
os.environ.setdefault("SSL_CERT_DIR", os.path.dirname(_certifi_ca_bundle))
|
||||
# Also update the default SSL context for urllib
|
||||
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=_certifi_ca_bundle)
|
||||
except ImportError:
|
||||
_certifi_ca_bundle = None
|
||||
|
||||
import requests
|
||||
import whisper
|
||||
import huggingface_hub
|
||||
|
|
@ -37,68 +23,16 @@ from huggingface_hub.errors import LocalEntryNotFoundError
|
|||
|
||||
from buzz.locale import _
|
||||
|
||||
# Configure huggingface_hub to use certifi certificates directly.
|
||||
# This is more reliable than environment variables for frozen apps.
|
||||
if _certifi_ca_bundle is not None:
|
||||
try:
|
||||
from huggingface_hub import configure_http_backend
|
||||
|
||||
def _hf_session_factory() -> requests.Session:
|
||||
session = requests.Session()
|
||||
session.verify = _certifi_ca_bundle
|
||||
return session
|
||||
|
||||
configure_http_backend(backend_factory=_hf_session_factory)
|
||||
except ImportError:
|
||||
# configure_http_backend not available in older huggingface_hub versions
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.debug(f"Failed to configure huggingface_hub HTTP backend: {e}")
|
||||
|
||||
# On Windows, creating symlinks requires special privileges (Developer Mode or
|
||||
# SeCreateSymbolicLinkPrivilege). Monkey-patch huggingface_hub to use file
|
||||
# copying instead of symlinks to avoid [WinError 1314] errors.
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
from huggingface_hub import file_download
|
||||
from pathlib import Path
|
||||
|
||||
_original_create_symlink = file_download._create_symlink
|
||||
|
||||
def _windows_create_symlink(src: Path, dst: Path, new_blob: bool = False) -> None:
|
||||
"""Windows-compatible replacement that copies instead of symlinking."""
|
||||
src = Path(src)
|
||||
dst = Path(dst)
|
||||
|
||||
# If dst already exists and is correct, skip
|
||||
if dst.exists():
|
||||
if dst.is_symlink():
|
||||
# Existing symlink - leave it
|
||||
return
|
||||
if dst.is_file():
|
||||
# Check if it's the same file
|
||||
if dst.stat().st_size == src.stat().st_size:
|
||||
return
|
||||
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Try symlink first (works if Developer Mode is enabled)
|
||||
try:
|
||||
dst.unlink(missing_ok=True)
|
||||
os.symlink(src, dst)
|
||||
return
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Fallback: copy the file instead
|
||||
dst.unlink(missing_ok=True)
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
file_download._create_symlink = _windows_create_symlink
|
||||
logging.debug("Patched huggingface_hub to use file copying on Windows")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch huggingface_hub for Windows: {e}")
|
||||
# Catch exception from whisper.dll not getting loaded.
|
||||
# TODO: Remove flag and try-except when issue with loading
|
||||
# the DLL in some envs is fixed.
|
||||
LOADED_WHISPER_CPP_BINARY = False
|
||||
try:
|
||||
import buzz.whisper_cpp as whisper_cpp # noqa: F401
|
||||
|
||||
LOADED_WHISPER_CPP_BINARY = True
|
||||
except ImportError:
|
||||
logging.exception("")
|
||||
|
||||
model_root_dir = user_cache_dir("Buzz")
|
||||
model_root_dir = os.path.join(model_root_dir, "models")
|
||||
|
|
@ -107,21 +41,17 @@ os.makedirs(model_root_dir, exist_ok=True)
|
|||
|
||||
logging.debug("Model root directory: %s", model_root_dir)
|
||||
|
||||
|
||||
class WhisperModelSize(str, enum.Enum):
|
||||
TINY = "tiny"
|
||||
TINYEN = "tiny.en"
|
||||
BASE = "base"
|
||||
BASEEN = "base.en"
|
||||
SMALL = "small"
|
||||
SMALLEN = "small.en"
|
||||
MEDIUM = "medium"
|
||||
MEDIUMEN = "medium.en"
|
||||
LARGE = "large"
|
||||
LARGEV2 = "large-v2"
|
||||
LARGEV3 = "large-v3"
|
||||
LARGEV3TURBO = "large-v3-turbo"
|
||||
CUSTOM = "custom"
|
||||
LUMII = "lumii"
|
||||
|
||||
def to_faster_whisper_model_size(self) -> str:
|
||||
if self == WhisperModelSize.LARGE:
|
||||
|
|
@ -136,25 +66,6 @@ class WhisperModelSize(str, enum.Enum):
|
|||
def __str__(self):
|
||||
return self.value.capitalize()
|
||||
|
||||
# Approximate expected file sizes for Whisper models (based on actual .pt file sizes)
|
||||
WHISPER_MODEL_SIZES = {
|
||||
WhisperModelSize.TINY: 72 * 1024 * 1024, # ~73 MB actual
|
||||
WhisperModelSize.TINYEN: 72 * 1024 * 1024, # ~73 MB actual
|
||||
WhisperModelSize.BASE: 138 * 1024 * 1024, # ~139 MB actual
|
||||
WhisperModelSize.BASEEN: 138 * 1024 * 1024, # ~139 MB actual
|
||||
WhisperModelSize.SMALL: 460 * 1024 * 1024, # ~462 MB actual
|
||||
WhisperModelSize.SMALLEN: 460 * 1024 * 1024, # ~462 MB actual
|
||||
WhisperModelSize.MEDIUM: 1500 * 1024 * 1024, # ~1.5 GB actual
|
||||
WhisperModelSize.MEDIUMEN: 1500 * 1024 * 1024, # ~1.5 GB actual
|
||||
WhisperModelSize.LARGE: 2870 * 1024 * 1024, # ~2.9 GB actual
|
||||
WhisperModelSize.LARGEV2: 2870 * 1024 * 1024, # ~2.9 GB actual
|
||||
WhisperModelSize.LARGEV3: 2870 * 1024 * 1024, # ~2.9 GB actual
|
||||
WhisperModelSize.LARGEV3TURBO: 1550 * 1024 * 1024, # ~1.6 GB actual (turbo is smaller)
|
||||
}
|
||||
|
||||
def get_expected_whisper_model_size(size: WhisperModelSize) -> Optional[int]:
|
||||
"""Get expected file size for a Whisper model without network request."""
|
||||
return WHISPER_MODEL_SIZES.get(size, None)
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
WHISPER = "Whisper"
|
||||
|
|
@ -174,6 +85,13 @@ class ModelType(enum.Enum):
|
|||
|
||||
def is_available(self):
|
||||
if (
|
||||
# Hide Whisper.cpp option if whisper.dll did not load correctly.
|
||||
# See: https://github.com/chidiwilliams/buzz/issues/274,
|
||||
# https://github.com/chidiwilliams/buzz/issues/197
|
||||
(self == ModelType.WHISPER_CPP and not LOADED_WHISPER_CPP_BINARY)
|
||||
):
|
||||
return False
|
||||
elif (
|
||||
# Hide Faster Whisper option on macOS x86_64
|
||||
# See: https://github.com/SYSTRAN/faster-whisper/issues/541
|
||||
(self == ModelType.FASTER_WHISPER
|
||||
|
|
@ -208,80 +126,6 @@ HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
|
|||
"vocab.json",
|
||||
]
|
||||
|
||||
# MMS models use different patterns - adapters are downloaded on-demand by transformers
|
||||
MMS_MODEL_ALLOW_PATTERNS = [
|
||||
"model.safetensors",
|
||||
"pytorch_model.bin",
|
||||
"config.json",
|
||||
"preprocessor_config.json",
|
||||
"tokenizer_config.json",
|
||||
"vocab.json",
|
||||
"special_tokens_map.json",
|
||||
"added_tokens.json",
|
||||
]
|
||||
|
||||
# ISO 639-1 to ISO 639-3 language code mapping for MMS models
|
||||
ISO_639_1_TO_3 = {
|
||||
"en": "eng", "fr": "fra", "de": "deu", "es": "spa", "it": "ita",
|
||||
"pt": "por", "ru": "rus", "ja": "jpn", "ko": "kor", "zh": "cmn",
|
||||
"ar": "ara", "hi": "hin", "nl": "nld", "pl": "pol", "sv": "swe",
|
||||
"tr": "tur", "uk": "ukr", "vi": "vie", "cs": "ces", "da": "dan",
|
||||
"fi": "fin", "el": "ell", "he": "heb", "hu": "hun", "id": "ind",
|
||||
"ms": "zsm", "no": "nob", "ro": "ron", "sk": "slk", "th": "tha",
|
||||
"bg": "bul", "ca": "cat", "hr": "hrv", "lt": "lit", "lv": "lav",
|
||||
"sl": "slv", "et": "est", "sr": "srp", "tl": "tgl", "bn": "ben",
|
||||
"ta": "tam", "te": "tel", "mr": "mar", "gu": "guj", "kn": "kan",
|
||||
"ml": "mal", "pa": "pan", "ur": "urd", "fa": "pes", "sw": "swh",
|
||||
"af": "afr", "az": "azj", "be": "bel", "bs": "bos", "cy": "cym",
|
||||
"eo": "epo", "eu": "eus", "ga": "gle", "gl": "glg", "hy": "hye",
|
||||
"is": "isl", "ka": "kat", "kk": "kaz", "km": "khm", "lo": "lao",
|
||||
"mk": "mkd", "mn": "khk", "my": "mya", "ne": "npi", "si": "sin",
|
||||
"sq": "sqi", "uz": "uzn", "zu": "zul", "am": "amh", "jw": "jav",
|
||||
"la": "lat", "so": "som", "su": "sun", "tt": "tat", "yo": "yor",
|
||||
}
|
||||
|
||||
|
||||
def map_language_to_mms(language_code: str) -> str:
|
||||
"""Convert ISO 639-1 code to ISO 639-3 code for MMS models.
|
||||
|
||||
If the code is already 3 letters, returns it as-is.
|
||||
If the code is not found in the mapping, returns as-is.
|
||||
"""
|
||||
if not language_code:
|
||||
return "eng" # Default to English for MMS
|
||||
if len(language_code) == 3:
|
||||
return language_code # Already ISO 639-3
|
||||
return ISO_639_1_TO_3.get(language_code, language_code)
|
||||
|
||||
|
||||
def is_mms_model(model_id: str) -> bool:
|
||||
"""Detect if a HuggingFace model is an MMS (Massively Multilingual Speech) model.
|
||||
|
||||
Detection criteria:
|
||||
1. Model ID contains "mms-" (e.g., facebook/mms-1b-all)
|
||||
2. Model config has model_type == "wav2vec2" with adapter architecture
|
||||
"""
|
||||
if not model_id:
|
||||
return False
|
||||
|
||||
# Fast check: model ID pattern
|
||||
if "mms-" in model_id.lower():
|
||||
return True
|
||||
|
||||
# For cached/downloaded models, check config.json
|
||||
try:
|
||||
import json
|
||||
config_path = huggingface_hub.hf_hub_download(
|
||||
model_id, "config.json", local_files_only=True, cache_dir=model_root_dir
|
||||
)
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
# MMS models have model_type "wav2vec2" and use adapter architecture
|
||||
return (config.get("model_type") == "wav2vec2"
|
||||
and config.get("adapter_attn_dim") is not None)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TranscriptionModel:
|
||||
|
|
@ -346,10 +190,8 @@ class TranscriptionModel:
|
|||
def delete_local_file(self):
|
||||
model_path = self.get_local_model_path()
|
||||
|
||||
if self.model_type in (ModelType.HUGGING_FACE,
|
||||
ModelType.FASTER_WHISPER):
|
||||
# Go up two directories to get the huggingface cache root for this model
|
||||
# Structure: models--repo--name/snapshots/xxx/files
|
||||
if (self.model_type == ModelType.HUGGING_FACE
|
||||
or self.model_type == ModelType.FASTER_WHISPER):
|
||||
model_path = os.path.dirname(os.path.dirname(model_path))
|
||||
|
||||
logging.debug("Deleting model directory: %s", model_path)
|
||||
|
|
@ -357,32 +199,6 @@ class TranscriptionModel:
|
|||
shutil.rmtree(model_path, ignore_errors=True)
|
||||
return
|
||||
|
||||
if self.model_type == ModelType.WHISPER_CPP:
|
||||
if self.whisper_model_size == WhisperModelSize.CUSTOM:
|
||||
# Custom models are stored as a single .bin file directly in model_root_dir
|
||||
logging.debug("Deleting model file: %s", model_path)
|
||||
os.remove(model_path)
|
||||
else:
|
||||
# Non-custom models are downloaded via huggingface_hub.
|
||||
# Multiple models share the same repo directory, so we only delete
|
||||
# the specific model files, not the entire directory.
|
||||
logging.debug("Deleting model file: %s", model_path)
|
||||
os.remove(model_path)
|
||||
|
||||
# Also delete CoreML files if they exist (.mlmodelc.zip and extracted directory)
|
||||
model_dir = os.path.dirname(model_path)
|
||||
model_name = self.whisper_model_size.to_whisper_cpp_model_size()
|
||||
coreml_zip = os.path.join(model_dir, f"ggml-{model_name}-encoder.mlmodelc.zip")
|
||||
coreml_dir = os.path.join(model_dir, f"ggml-{model_name}-encoder.mlmodelc")
|
||||
|
||||
if os.path.exists(coreml_zip):
|
||||
logging.debug("Deleting CoreML zip: %s", coreml_zip)
|
||||
os.remove(coreml_zip)
|
||||
if os.path.exists(coreml_dir):
|
||||
logging.debug("Deleting CoreML directory: %s", coreml_dir)
|
||||
shutil.rmtree(coreml_dir, ignore_errors=True)
|
||||
return
|
||||
|
||||
logging.debug("Deleting model file: %s", model_path)
|
||||
os.remove(model_path)
|
||||
|
||||
|
|
@ -397,21 +213,7 @@ class TranscriptionModel:
|
|||
file_path = get_whisper_file_path(size=self.whisper_model_size)
|
||||
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
||||
return None
|
||||
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
expected_size = get_expected_whisper_model_size(self.whisper_model_size)
|
||||
|
||||
if expected_size is not None:
|
||||
if file_size < expected_size * 0.95: # Allow 5% tolerance for file system differences
|
||||
return None
|
||||
return file_path
|
||||
else:
|
||||
# For unknown model size
|
||||
if file_size < 50 * 1024 * 1024:
|
||||
return None
|
||||
|
||||
return file_path
|
||||
return file_path
|
||||
|
||||
if self.model_type == ModelType.FASTER_WHISPER:
|
||||
try:
|
||||
|
|
@ -440,23 +242,17 @@ class TranscriptionModel:
|
|||
|
||||
|
||||
WHISPER_CPP_REPO_ID = "ggerganov/whisper.cpp"
|
||||
WHISPER_CPP_LUMII_REPO_ID = "RaivisDejus/whisper.cpp-lv"
|
||||
|
||||
|
||||
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
|
||||
if size == WhisperModelSize.CUSTOM:
|
||||
return os.path.join(model_root_dir, f"ggml-model-whisper-custom.bin")
|
||||
|
||||
repo_id = WHISPER_CPP_REPO_ID
|
||||
|
||||
if size == WhisperModelSize.LUMII:
|
||||
repo_id = WHISPER_CPP_LUMII_REPO_ID
|
||||
|
||||
model_filename = f"ggml-{size.to_whisper_cpp_model_size()}.bin"
|
||||
|
||||
try:
|
||||
model_path = huggingface_hub.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
model_path = huggingface_hub.snapshot_download(
|
||||
repo_id=WHISPER_CPP_REPO_ID,
|
||||
allow_patterns=[model_filename],
|
||||
local_files_only=True,
|
||||
cache_dir=model_root_dir,
|
||||
|
|
@ -482,8 +278,7 @@ class HuggingfaceDownloadMonitor:
|
|||
def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size: int):
|
||||
self.model_root = model_root
|
||||
self.progress = progress
|
||||
# To keep dialog open even if it reports 100%
|
||||
self.total_file_size = round(total_file_size * 1.1)
|
||||
self.total_file_size = total_file_size
|
||||
self.incomplete_download_root = None
|
||||
self.stop_event = threading.Event()
|
||||
self.monitor_thread = None
|
||||
|
|
@ -491,10 +286,8 @@ class HuggingfaceDownloadMonitor:
|
|||
|
||||
def set_download_roots(self):
|
||||
normalized_model_root = os.path.normpath(self.model_root)
|
||||
two_dirs_up = os.path.normpath(
|
||||
os.path.join(normalized_model_root, "..", ".."))
|
||||
self.incomplete_download_root = os.path.normpath(
|
||||
os.path.join(two_dirs_up, "blobs"))
|
||||
two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
|
||||
self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))
|
||||
|
||||
def clean_tmp_files(self):
|
||||
for filename in os.listdir(model_root_dir):
|
||||
|
|
@ -503,28 +296,16 @@ class HuggingfaceDownloadMonitor:
|
|||
|
||||
def monitor_file_size(self):
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
if model_root_dir is not None and os.path.isdir(model_root_dir):
|
||||
for filename in os.listdir(model_root_dir):
|
||||
if filename.startswith("tmp"):
|
||||
try:
|
||||
file_size = os.path.getsize(
|
||||
os.path.join(model_root_dir, filename))
|
||||
self.progress.emit((file_size, self.total_file_size))
|
||||
except OSError:
|
||||
pass # File may have been deleted
|
||||
if model_root_dir is not None:
|
||||
for filename in os.listdir(model_root_dir):
|
||||
if filename.startswith("tmp"):
|
||||
file_size = os.path.getsize(os.path.join(model_root_dir, filename))
|
||||
self.progress.emit((file_size, self.total_file_size))
|
||||
|
||||
if self.incomplete_download_root and os.path.isdir(self.incomplete_download_root):
|
||||
for filename in os.listdir(self.incomplete_download_root):
|
||||
if filename.endswith(".incomplete"):
|
||||
try:
|
||||
file_size = os.path.getsize(os.path.join(
|
||||
self.incomplete_download_root, filename))
|
||||
self.progress.emit((file_size, self.total_file_size))
|
||||
except OSError:
|
||||
pass # File may have been deleted
|
||||
except OSError:
|
||||
pass # Directory listing failed, ignore
|
||||
for filename in os.listdir(self.incomplete_download_root):
|
||||
if filename.endswith(".incomplete"):
|
||||
file_size = os.path.getsize(os.path.join(self.incomplete_download_root, filename))
|
||||
self.progress.emit((file_size, self.total_file_size))
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
|
@ -558,8 +339,7 @@ def download_from_huggingface(
|
|||
try:
|
||||
model_root = huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
# all, but largest
|
||||
allow_patterns=allow_patterns[num_large_files:],
|
||||
allow_patterns=allow_patterns[num_large_files:], # all, but largest
|
||||
cache_dir=model_root_dir,
|
||||
etag_timeout=60
|
||||
)
|
||||
|
|
@ -581,8 +361,7 @@ def download_from_huggingface(
|
|||
except requests.exceptions.RequestException as e:
|
||||
continue
|
||||
|
||||
model_download_monitor = HuggingfaceDownloadMonitor(
|
||||
model_root, progress, largest_file_size)
|
||||
model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, largest_file_size)
|
||||
model_download_monitor.start_monitoring()
|
||||
|
||||
try:
|
||||
|
|
@ -595,7 +374,9 @@ def download_from_huggingface(
|
|||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
model_download_monitor.stop_monitoring()
|
||||
|
||||
# Cleanup to prevent incomplete downloads errors
|
||||
if os.path.exists(model_root):
|
||||
shutil.rmtree(model_root)
|
||||
return ""
|
||||
|
||||
model_download_monitor.stop_monitoring()
|
||||
|
|
@ -614,19 +395,23 @@ def download_faster_whisper_model(
|
|||
|
||||
if size == WhisperModelSize.CUSTOM:
|
||||
repo_id = custom_repo_id
|
||||
# Replicating models from faster-whisper code https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py#L29
|
||||
elif size == WhisperModelSize.LARGEV3:
|
||||
repo_id = "Systran/faster-whisper-large-v3"
|
||||
# Maybe switch to 'mobiuslabsgmbh/faster-whisper-large-v3-turbo', seems to be used in
|
||||
# faster-whisper code https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py#L29
|
||||
# If so changes needed also in whisper_file_transcriber.py
|
||||
elif size == WhisperModelSize.LARGEV3TURBO:
|
||||
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
|
||||
repo_id = "deepdml/faster-whisper-large-v3-turbo-ct2"
|
||||
else:
|
||||
repo_id = "Systran/faster-whisper-%s" % size
|
||||
repo_id = "guillaumekln/faster-whisper-%s" % size
|
||||
|
||||
allow_patterns = [
|
||||
"model.bin", # largest by size first
|
||||
"pytorch_model.bin", # possible alternative model filename
|
||||
"config.json",
|
||||
"preprocessor_config.json",
|
||||
"tokenizer.json",
|
||||
"vocabulary.*",
|
||||
"vocabulary.txt",
|
||||
"vocabulary.json",
|
||||
]
|
||||
|
||||
if local_files_only:
|
||||
|
|
@ -655,29 +440,21 @@ class ModelDownloader(QRunnable):
|
|||
def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
||||
self.is_coreml_supported = platform.system(
|
||||
) == "Darwin" and platform.machine() == "arm64"
|
||||
self.is_coreml_supported = platform.system() == "Darwin" and platform.machine() == "arm64"
|
||||
self.signals = self.Signals()
|
||||
self.model = model
|
||||
self.stopped = False
|
||||
self.custom_model_url = custom_model_url
|
||||
|
||||
def run(self) -> None:
|
||||
logging.debug("Downloading model: %s, %s", self.model,
|
||||
self.model.hugging_face_model_id)
|
||||
logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)
|
||||
|
||||
if self.model.model_type == ModelType.WHISPER_CPP:
|
||||
if self.custom_model_url:
|
||||
url = self.custom_model_url
|
||||
file_path = get_whisper_cpp_file_path(
|
||||
size=self.model.whisper_model_size)
|
||||
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
|
||||
return self.download_model_to_path(url=url, file_path=file_path)
|
||||
|
||||
repo_id = WHISPER_CPP_REPO_ID
|
||||
|
||||
if self.model.whisper_model_size == WhisperModelSize.LUMII:
|
||||
repo_id = WHISPER_CPP_LUMII_REPO_ID
|
||||
|
||||
model_name = self.model.whisper_model_size.to_whisper_cpp_model_size()
|
||||
|
||||
whisper_cpp_model_files = [
|
||||
|
|
@ -687,64 +464,30 @@ class ModelDownloader(QRunnable):
|
|||
num_large_files = 1
|
||||
if self.is_coreml_supported:
|
||||
whisper_cpp_model_files = [
|
||||
f"ggml-{model_name}.bin",
|
||||
f"ggml-{model_name}-encoder.mlmodelc.zip",
|
||||
"README.md"
|
||||
f"ggml-{model_name}.bin",
|
||||
f"ggml-{model_name}-encoder.mlmodelc.zip",
|
||||
"README.md"
|
||||
]
|
||||
num_large_files = 2
|
||||
|
||||
model_path = download_from_huggingface(
|
||||
repo_id=repo_id,
|
||||
repo_id=WHISPER_CPP_REPO_ID,
|
||||
allow_patterns=whisper_cpp_model_files,
|
||||
progress=self.signals.progress,
|
||||
num_large_files=num_large_files
|
||||
)
|
||||
|
||||
if self.is_coreml_supported:
|
||||
import tempfile
|
||||
with zipfile.ZipFile(
|
||||
os.path.join(model_path, f"ggml-{model_name}-encoder.mlmodelc.zip"), 'r') as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
|
||||
target_dir = os.path.join(model_path, f"ggml-{model_name}-encoder.mlmodelc")
|
||||
zip_path = os.path.join(model_path, f"ggml-{model_name}-encoder.mlmodelc.zip")
|
||||
|
||||
# Remove target directory if it exists
|
||||
if os.path.exists(target_dir):
|
||||
shutil.rmtree(target_dir)
|
||||
|
||||
# Extract to a temporary directory first
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
# Remove __MACOSX metadata folders if present
|
||||
macosx_path = os.path.join(temp_dir, "__MACOSX")
|
||||
if os.path.exists(macosx_path):
|
||||
shutil.rmtree(macosx_path)
|
||||
|
||||
# Check if there's a single top-level directory
|
||||
temp_contents = os.listdir(temp_dir)
|
||||
if len(temp_contents) == 1 and os.path.isdir(os.path.join(temp_dir, temp_contents[0])):
|
||||
# Single directory - move its contents to target
|
||||
nested_dir = os.path.join(temp_dir, temp_contents[0])
|
||||
shutil.move(nested_dir, target_dir)
|
||||
else:
|
||||
# Multiple items or files - copy everything to target
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
for item in temp_contents:
|
||||
src = os.path.join(temp_dir, item)
|
||||
dst = os.path.join(target_dir, item)
|
||||
if os.path.isdir(src):
|
||||
shutil.copytree(src, dst)
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
self.signals.finished.emit(os.path.join(
|
||||
model_path, f"ggml-{model_name}.bin"))
|
||||
self.signals.finished.emit(os.path.join(model_path, f"ggml-{model_name}.bin"))
|
||||
return
|
||||
|
||||
if self.model.model_type == ModelType.WHISPER:
|
||||
url = whisper._MODELS[self.model.whisper_model_size.value]
|
||||
file_path = get_whisper_file_path(
|
||||
size=self.model.whisper_model_size)
|
||||
file_path = get_whisper_file_path(size=self.model.whisper_model_size)
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
return self.download_model_to_path(
|
||||
url=url, file_path=file_path, expected_sha256=expected_sha256
|
||||
|
|
@ -769,10 +512,6 @@ class ModelDownloader(QRunnable):
|
|||
progress=self.signals.progress,
|
||||
num_large_files=4
|
||||
)
|
||||
|
||||
if model_path == "":
|
||||
self.signals.error.emit(_("Error"))
|
||||
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
|
|
@ -789,18 +528,16 @@ class ModelDownloader(QRunnable):
|
|||
downloaded = self.download_model(url, file_path, expected_sha256)
|
||||
if downloaded:
|
||||
self.signals.finished.emit(file_path)
|
||||
except requests.RequestException as e:
|
||||
except requests.RequestException:
|
||||
self.signals.error.emit(_("A connection error occurred"))
|
||||
if not self.stopped and "timeout" not in str(e).lower():
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
logging.exception("")
|
||||
except Exception as exc:
|
||||
self.signals.error.emit(str(exc))
|
||||
if not self.stopped:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
logging.exception(exc)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
logging.exception(exc)
|
||||
|
||||
def download_model(
|
||||
self, url: str, file_path: str, expected_sha256: Optional[str]
|
||||
|
|
@ -812,190 +549,41 @@ class ModelDownloader(QRunnable):
|
|||
if os.path.exists(file_path) and not os.path.isfile(file_path):
|
||||
raise RuntimeError(f"{file_path} exists and is not a regular file")
|
||||
|
||||
resume_from = 0
|
||||
file_mode = "wb"
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
file_size = os.path.getsize(file_path)
|
||||
if expected_sha256 is None:
|
||||
return True
|
||||
|
||||
if expected_sha256 is not None:
|
||||
# Get the expected file size from URL
|
||||
try:
|
||||
head_response = requests.head(url, timeout=5, allow_redirects=True)
|
||||
expected_size = int(head_response.headers.get("Content-Length", 0))
|
||||
|
||||
if expected_size > 0:
|
||||
if file_size < expected_size:
|
||||
resume_from = file_size
|
||||
file_mode = "ab"
|
||||
logging.debug(
|
||||
f"File incomplete ({file_size}/{expected_size} bytes), resuming from byte {resume_from}"
|
||||
)
|
||||
elif file_size == expected_size:
|
||||
# This means file size matches - verify SHA256 to confirm it is complete
|
||||
try:
|
||||
# Use chunked reading to avoid loading entire file into memory
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
model_sha256 = sha256_hash.hexdigest()
|
||||
if model_sha256 == expected_sha256:
|
||||
logging.debug("Model already downloaded and verified")
|
||||
return True
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
# File exists but it is wrong, delete it
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
logging.warning(f"Error checking existing file: {e}")
|
||||
os.remove(file_path)
|
||||
else:
|
||||
# File is larger than expected - corrupted, delete it
|
||||
warnings.warn(f"File size ({file_size}) exceeds expected size ({expected_size}), re-downloading")
|
||||
os.remove(file_path)
|
||||
else:
|
||||
# Can't get expected size - use threshold approach
|
||||
if file_size < 10 * 1024 * 1024:
|
||||
resume_from = file_size
|
||||
file_mode = "ab" # Append mode to resume
|
||||
logging.debug(f"Resuming download from byte {resume_from}")
|
||||
else:
|
||||
# Large file - verify SHA256 using chunked reading
|
||||
try:
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
model_sha256 = sha256_hash.hexdigest()
|
||||
if model_sha256 == expected_sha256:
|
||||
logging.debug("Model already downloaded and verified")
|
||||
return True
|
||||
else:
|
||||
warnings.warn("SHA256 mismatch, re-downloading")
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
logging.warning(f"Error verifying file: {e}")
|
||||
os.remove(file_path)
|
||||
|
||||
except Exception as e:
|
||||
# Can't get expected size - use threshold
|
||||
logging.debug(f"Could not get expected file size: {e}, using threshold")
|
||||
if file_size < 10 * 1024 * 1024:
|
||||
resume_from = file_size
|
||||
file_mode = "ab"
|
||||
logging.debug(f"Resuming from byte {resume_from}")
|
||||
model_bytes = open(file_path, "rb").read()
|
||||
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
|
||||
if model_sha256 == expected_sha256:
|
||||
return True
|
||||
else:
|
||||
# No SHA256 to verify - just check file size
|
||||
if file_size > 0:
|
||||
resume_from = file_size
|
||||
file_mode = "ab"
|
||||
logging.debug(f"Resuming download from byte {resume_from}")
|
||||
warnings.warn(
|
||||
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
tmp_file = tempfile.mktemp()
|
||||
logging.debug("Downloading to temporary file = %s", tmp_file)
|
||||
|
||||
# Downloads the model using the requests module instead of urllib to
|
||||
# use the certs from certifi when the app is running in frozen mode
|
||||
|
||||
# Check if server supports Range requests before starting download
|
||||
supports_range = False
|
||||
if resume_from > 0:
|
||||
try:
|
||||
head_resp = requests.head(url, timeout=10, allow_redirects=True)
|
||||
accept_ranges = head_resp.headers.get("Accept-Ranges", "").lower()
|
||||
supports_range = accept_ranges == "bytes"
|
||||
if not supports_range:
|
||||
logging.debug("Server doesn't support Range requests, starting from beginning")
|
||||
resume_from = 0
|
||||
file_mode = "wb"
|
||||
except requests.RequestException as e:
|
||||
logging.debug(f"HEAD request failed, starting fresh: {e}")
|
||||
resume_from = 0
|
||||
file_mode = "wb"
|
||||
|
||||
headers = {}
|
||||
if resume_from > 0 and supports_range:
|
||||
headers["Range"] = f"bytes={resume_from}-"
|
||||
|
||||
# Use a temporary file for fresh downloads to ensure atomic writes
|
||||
temp_file_path = None
|
||||
if resume_from == 0:
|
||||
temp_file_path = file_path + ".downloading"
|
||||
# Clean up any existing temp file
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
except OSError:
|
||||
pass
|
||||
download_path = temp_file_path
|
||||
else:
|
||||
download_path = file_path
|
||||
|
||||
try:
|
||||
with requests.get(url, stream=True, timeout=30, headers=headers) as source:
|
||||
source.raise_for_status()
|
||||
|
||||
if resume_from > 0:
|
||||
if source.status_code == 206:
|
||||
logging.debug(
|
||||
f"Server supports resume, continuing from byte {resume_from}")
|
||||
content_range = source.headers.get("Content-Range", "")
|
||||
if "/" in content_range:
|
||||
total_size = int(content_range.split("/")[-1])
|
||||
else:
|
||||
total_size = resume_from + int(source.headers.get("Content-Length", 0))
|
||||
current = resume_from
|
||||
else:
|
||||
# Server returned 200 instead of 206, need to start over
|
||||
logging.debug("Server returned 200 instead of 206, starting fresh")
|
||||
resume_from = 0
|
||||
file_mode = "wb"
|
||||
temp_file_path = file_path + ".downloading"
|
||||
download_path = temp_file_path
|
||||
total_size = float(source.headers.get("Content-Length", 0))
|
||||
current = 0.0
|
||||
else:
|
||||
total_size = float(source.headers.get("Content-Length", 0))
|
||||
current = 0.0
|
||||
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(
|
||||
tmp_file, "wb"
|
||||
) as output:
|
||||
source.raise_for_status()
|
||||
total_size = float(source.headers.get("Content-Length", 0))
|
||||
current = 0.0
|
||||
self.signals.progress.emit((current, total_size))
|
||||
for chunk in source.iter_content(chunk_size=8192):
|
||||
if self.stopped:
|
||||
return False
|
||||
output.write(chunk)
|
||||
current += len(chunk)
|
||||
self.signals.progress.emit((current, total_size))
|
||||
|
||||
with open(download_path, file_mode) as output:
|
||||
for chunk in source.iter_content(chunk_size=8192):
|
||||
if self.stopped:
|
||||
return False
|
||||
output.write(chunk)
|
||||
current += len(chunk)
|
||||
self.signals.progress.emit((current, total_size))
|
||||
|
||||
# If we used a temp file, rename it to the final path
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
# Remove existing file if present
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
shutil.move(temp_file_path, file_path)
|
||||
|
||||
except Exception:
|
||||
# Clean up temp file on error
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
if expected_sha256 is not None:
|
||||
# Use chunked reading to avoid loading entire file into memory
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
if sha256_hash.hexdigest() != expected_sha256:
|
||||
# Delete the corrupted file before raising the error
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except OSError as e:
|
||||
logging.warning(f"Failed to delete corrupted model file: {e}")
|
||||
model_bytes = open(tmp_file, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
|
||||
"model."
|
||||
|
|
@ -1003,7 +591,17 @@ class ModelDownloader(QRunnable):
|
|||
|
||||
logging.debug("Downloaded model")
|
||||
|
||||
# https://github.com/chidiwilliams/buzz/issues/454
|
||||
shutil.move(tmp_file, file_path)
|
||||
logging.debug("Moved file from %s to %s", tmp_file, file_path)
|
||||
return True
|
||||
|
||||
def cancel(self):
|
||||
self.stopped = True
|
||||
|
||||
|
||||
def get_custom_api_whisper_model(base_url: str):
|
||||
if "api.groq.com" in base_url:
|
||||
return "whisper-large-v3"
|
||||
|
||||
return "whisper-1"
|
||||
|
|
|
|||
|
|
@ -9,9 +9,6 @@ from PyQt6.QtCore import QObject, pyqtSignal
|
|||
class RecordingAmplitudeListener(QObject):
|
||||
stream: Optional[sounddevice.InputStream] = None
|
||||
amplitude_changed = pyqtSignal(float)
|
||||
average_amplitude_changed = pyqtSignal(float)
|
||||
|
||||
ACCUMULATION_SECONDS = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -20,9 +17,6 @@ class RecordingAmplitudeListener(QObject):
|
|||
):
|
||||
super().__init__(parent)
|
||||
self.input_device_index = input_device_index
|
||||
self.buffer = np.ndarray([], dtype=np.float32)
|
||||
self.accumulation_size = 0
|
||||
self._active = True
|
||||
|
||||
def start_recording(self):
|
||||
try:
|
||||
|
|
@ -33,24 +27,16 @@ class RecordingAmplitudeListener(QObject):
|
|||
callback=self.stream_callback,
|
||||
)
|
||||
self.stream.start()
|
||||
self.accumulation_size = int(self.stream.samplerate * self.ACCUMULATION_SECONDS)
|
||||
except Exception as e:
|
||||
except sounddevice.PortAudioError:
|
||||
self.stop_recording()
|
||||
logging.exception("Failed to start audio stream on device %s: %s", self.input_device_index, e)
|
||||
logging.exception("")
|
||||
|
||||
def stop_recording(self):
|
||||
self._active = False
|
||||
if self.stream is not None:
|
||||
self.stream.stop()
|
||||
self.stream.close()
|
||||
|
||||
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
|
||||
if not self._active:
|
||||
return
|
||||
chunk = in_data.ravel()
|
||||
self.amplitude_changed.emit(float(np.sqrt(np.mean(chunk**2))))
|
||||
|
||||
self.buffer = np.append(self.buffer, chunk)
|
||||
if self.buffer.size >= self.accumulation_size:
|
||||
self.average_amplitude_changed.emit(float(np.sqrt(np.mean(self.buffer**2))))
|
||||
self.buffer = np.ndarray([], dtype=np.float32)
|
||||
amplitude = np.sqrt(np.mean(chunk**2)) # root-mean-square
|
||||
self.amplitude_changed.emit(amplitude)
|
||||
|
|
|
|||
|
|
@ -15,11 +15,7 @@ CREATE TABLE transcription (
|
|||
time_started TIMESTAMP,
|
||||
url TEXT,
|
||||
whisper_model_size TEXT,
|
||||
hugging_face_model_id TEXT,
|
||||
word_level_timings BOOLEAN DEFAULT FALSE,
|
||||
extract_speech BOOLEAN DEFAULT FALSE,
|
||||
name TEXT,
|
||||
notes TEXT
|
||||
hugging_face_model_id TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE transcription_segment (
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import enum
|
||||
import typing
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from PyQt6.QtCore import QSettings
|
||||
|
||||
|
|
@ -12,11 +11,13 @@ class Settings:
|
|||
def __init__(self, application=""):
|
||||
self.settings = QSettings(APP_NAME, application)
|
||||
self.settings.sync()
|
||||
logging.debug(f"Settings filename: {self.settings.fileName()}")
|
||||
|
||||
class Key(enum.Enum):
|
||||
RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task"
|
||||
RECORDING_TRANSCRIBER_MODEL = "recording-transcriber/model"
|
||||
RECORDING_TRANSCRIBER_LANGUAGE = "recording-transcriber/language"
|
||||
RECORDING_TRANSCRIBER_TEMPERATURE = "recording-transcriber/temperature"
|
||||
RECORDING_TRANSCRIBER_INITIAL_PROMPT = "recording-transcriber/initial-prompt"
|
||||
RECORDING_TRANSCRIBER_ENABLE_LLM_TRANSLATION = "recording-transcriber/enable-llm-translation"
|
||||
RECORDING_TRANSCRIBER_LLM_MODEL = "recording-transcriber/llm-model"
|
||||
|
|
@ -24,22 +25,11 @@ class Settings:
|
|||
RECORDING_TRANSCRIBER_EXPORT_ENABLED = "recording-transcriber/export-enabled"
|
||||
RECORDING_TRANSCRIBER_EXPORT_FOLDER = "recording-transcriber/export-folder"
|
||||
RECORDING_TRANSCRIBER_MODE = "recording-transcriber/mode"
|
||||
RECORDING_TRANSCRIBER_SILENCE_THRESHOLD = "recording-transcriber/silence-threshold"
|
||||
RECORDING_TRANSCRIBER_LINE_SEPARATOR = "recording-transcriber/line-separator"
|
||||
RECORDING_TRANSCRIBER_TRANSCRIPTION_STEP = "recording-transcriber/transcription-step"
|
||||
RECORDING_TRANSCRIBER_EXPORT_FILE_TYPE = "recording-transcriber/export-file-type"
|
||||
RECORDING_TRANSCRIBER_EXPORT_MAX_ENTRIES = "recording-transcriber/export-max-entries"
|
||||
RECORDING_TRANSCRIBER_EXPORT_FILE_NAME = "recording-transcriber/export-file-name"
|
||||
RECORDING_TRANSCRIBER_HIDE_UNCONFIRMED = "recording-transcriber/hide-unconfirmed"
|
||||
|
||||
PRESENTATION_WINDOW_TEXT_COLOR = "presentation-window/text-color"
|
||||
PRESENTATION_WINDOW_BACKGROUND_COLOR = "presentation-window/background-color"
|
||||
PRESENTATION_WINDOW_TEXT_SIZE = "presentation-window/text-size"
|
||||
PRESENTATION_WINDOW_THEME = "presentation-window/theme"
|
||||
|
||||
FILE_TRANSCRIBER_TASK = "file-transcriber/task"
|
||||
FILE_TRANSCRIBER_MODEL = "file-transcriber/model"
|
||||
FILE_TRANSCRIBER_LANGUAGE = "file-transcriber/language"
|
||||
FILE_TRANSCRIBER_TEMPERATURE = "file-transcriber/temperature"
|
||||
FILE_TRANSCRIBER_INITIAL_PROMPT = "file-transcriber/initial-prompt"
|
||||
FILE_TRANSCRIBER_ENABLE_LLM_TRANSLATION = "file-transcriber/enable-llm-translation"
|
||||
FILE_TRANSCRIBER_LLM_MODEL = "file-transcriber/llm-model"
|
||||
|
|
@ -49,7 +39,6 @@ class Settings:
|
|||
|
||||
DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name"
|
||||
CUSTOM_OPENAI_BASE_URL = "transcriber/custom-openai-base-url"
|
||||
OPENAI_API_MODEL = "transcriber/openai-api-model"
|
||||
CUSTOM_FASTER_WHISPER_ID = "transcriber/custom-faster-whisper-id"
|
||||
HUGGINGFACE_MODEL_ID = "transcriber/huggingface-model-id"
|
||||
|
||||
|
|
@ -57,40 +46,11 @@ class Settings:
|
|||
|
||||
FONT_SIZE = "font-size"
|
||||
|
||||
UI_LOCALE = "ui-locale"
|
||||
|
||||
USER_IDENTIFIER = "user-identifier"
|
||||
|
||||
TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY = (
|
||||
"transcription-tasks-table/column-visibility"
|
||||
)
|
||||
TRANSCRIPTION_TASKS_TABLE_COLUMN_ORDER = (
|
||||
"transcription-tasks-table/column-order"
|
||||
)
|
||||
TRANSCRIPTION_TASKS_TABLE_COLUMN_WIDTHS = (
|
||||
"transcription-tasks-table/column-widths"
|
||||
)
|
||||
TRANSCRIPTION_TASKS_TABLE_SORT_STATE = (
|
||||
"transcription-tasks-table/sort-state"
|
||||
)
|
||||
|
||||
MAIN_WINDOW = "main-window"
|
||||
TRANSCRIPTION_VIEWER = "transcription-viewer"
|
||||
|
||||
AUDIO_PLAYBACK_RATE = "audio/playback-rate"
|
||||
|
||||
FORCE_CPU = "force-cpu"
|
||||
REDUCE_GPU_MEMORY = "reduce-gpu-memory"
|
||||
|
||||
LAST_UPDATE_CHECK = "update/last-check"
|
||||
UPDATE_AVAILABLE_VERSION = "update/available-version"
|
||||
|
||||
def get_user_identifier(self) -> str:
|
||||
user_id = self.value(self.Key.USER_IDENTIFIER, "")
|
||||
if not user_id:
|
||||
user_id = str(uuid.uuid4())
|
||||
self.set_value(self.Key.USER_IDENTIFIER, user_id)
|
||||
return user_id
|
||||
|
||||
def set_value(self, key: Key, value: typing.Any) -> None:
|
||||
self.settings.setValue(key.value, value)
|
||||
|
|
@ -126,25 +86,16 @@ class Settings:
|
|||
return ""
|
||||
|
||||
def value(
|
||||
self,
|
||||
key: Key,
|
||||
default_value: typing.Any,
|
||||
value_type: typing.Optional[type] = None,
|
||||
self,
|
||||
key: Key,
|
||||
default_value: typing.Any,
|
||||
value_type: typing.Optional[type] = None,
|
||||
) -> typing.Any:
|
||||
val = self.settings.value(
|
||||
return self.settings.value(
|
||||
key.value,
|
||||
default_value,
|
||||
value_type if value_type is not None else type(default_value),
|
||||
)
|
||||
if (value_type is bool or isinstance(default_value, bool)):
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
if isinstance(val, str):
|
||||
return val.lower() in ("true", "1", "yes", "on")
|
||||
if isinstance(val, int):
|
||||
return val != 0
|
||||
return bool(val)
|
||||
return val
|
||||
|
||||
def clear(self):
|
||||
self.settings.clear()
|
||||
|
|
|
|||
|
|
@ -22,18 +22,6 @@ class Shortcut(str, enum.Enum):
|
|||
VIEW_TRANSCRIPT_TEXT = ("Ctrl+E", _("View Transcript Text"))
|
||||
VIEW_TRANSCRIPT_TRANSLATION = ("Ctrl+L", _("View Transcript Translation"))
|
||||
VIEW_TRANSCRIPT_TIMESTAMPS = ("Ctrl+T", _("View Transcript Timestamps"))
|
||||
SEARCH_TRANSCRIPT = ("Ctrl+F", _("Search Transcript"))
|
||||
SEARCH_NEXT = ("Ctrl+Return", _("Go to Next Transcript Search Result"))
|
||||
SEARCH_PREVIOUS = ("Shift+Return", _("Go to Previous Transcript Search Result"))
|
||||
SCROLL_TO_CURRENT_TEXT = ("Ctrl+G", _("Scroll to Current Text"))
|
||||
PLAY_PAUSE_AUDIO = ("Ctrl+P", _("Play/Pause Audio"))
|
||||
REPLAY_CURRENT_SEGMENT = ("Ctrl+Shift+P", _("Replay Current Segment"))
|
||||
TOGGLE_PLAYBACK_CONTROLS = ("Ctrl+Alt+P", _("Toggle Playback Controls"))
|
||||
|
||||
DECREASE_SEGMENT_START = ("Ctrl+Left", _("Decrease Segment Start Time"))
|
||||
INCREASE_SEGMENT_START = ("Ctrl+Right", _("Increase Segment Start Time"))
|
||||
DECREASE_SEGMENT_END = ("Ctrl+Shift+Left", _("Decrease Segment End Time"))
|
||||
INCREASE_SEGMENT_END = ("Ctrl+Shift+Right", _("Increase Segment End Time"))
|
||||
|
||||
CLEAR_HISTORY = ("Ctrl+S", _("Clear History"))
|
||||
STOP_TRANSCRIPTION = ("Ctrl+X", _("Cancel Transcription"))
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
import base64
|
||||
import enum
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import keyring
|
||||
|
||||
|
|
@ -15,199 +10,7 @@ class Key(enum.Enum):
|
|||
OPENAI_API_KEY = "OpenAI API key"
|
||||
|
||||
|
||||
def _is_linux() -> bool:
|
||||
return sys.platform.startswith("linux")
|
||||
|
||||
|
||||
def _get_secrets_file_path() -> str:
|
||||
"""Get the path to the local encrypted secrets file."""
|
||||
from platformdirs import user_data_dir
|
||||
|
||||
data_dir = user_data_dir(APP_NAME)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
return os.path.join(data_dir, ".secrets.json")
|
||||
|
||||
|
||||
def _get_portal_secret() -> bytes | None:
|
||||
"""Get the application secret from XDG Desktop Portal.
|
||||
|
||||
The portal provides a per-application secret that can be used
|
||||
for encrypting application-specific data. This works in sandboxed
|
||||
environments (Snap/Flatpak) via the desktop plug.
|
||||
"""
|
||||
if not _is_linux():
|
||||
return None
|
||||
|
||||
try:
|
||||
from jeepney import DBusAddress, new_method_call
|
||||
from jeepney.io.blocking import open_dbus_connection
|
||||
import socket
|
||||
|
||||
# Open connection with file descriptor support enabled
|
||||
conn = open_dbus_connection(bus="SESSION", enable_fds=True)
|
||||
|
||||
portal = DBusAddress(
|
||||
"/org/freedesktop/portal/desktop",
|
||||
bus_name="org.freedesktop.portal.Desktop",
|
||||
interface="org.freedesktop.portal.Secret",
|
||||
)
|
||||
|
||||
# Create a socket pair for receiving the secret
|
||||
sock_read, sock_write = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
|
||||
try:
|
||||
# Build the method call with file descriptor
|
||||
# RetrieveSecret(fd: h, options: a{sv}) -> (handle: o)
|
||||
# Pass the socket object directly - jeepney handles fd passing
|
||||
msg = new_method_call(portal, "RetrieveSecret", "ha{sv}", (sock_write, {}))
|
||||
|
||||
# Send message and get reply
|
||||
conn.send_and_get_reply(msg, timeout=10)
|
||||
|
||||
# Close the write end - portal has it now
|
||||
sock_write.close()
|
||||
sock_write = None
|
||||
|
||||
# Read the secret from the read end
|
||||
# The portal writes the secret and closes its end
|
||||
sock_read.settimeout(5.0)
|
||||
secret_data = b""
|
||||
while True:
|
||||
try:
|
||||
chunk = sock_read.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
secret_data += chunk
|
||||
except socket.timeout:
|
||||
break
|
||||
|
||||
if secret_data:
|
||||
return secret_data
|
||||
|
||||
return None
|
||||
|
||||
finally:
|
||||
sock_read.close()
|
||||
if sock_write is not None:
|
||||
sock_write.close()
|
||||
|
||||
except Exception as exc:
|
||||
logging.debug("XDG Portal secret not available: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _derive_key(master_secret: bytes, key_name: str) -> bytes:
|
||||
"""Derive a key-specific encryption key from the master secret."""
|
||||
# Use PBKDF2 to derive a key for this specific secret
|
||||
return hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
master_secret,
|
||||
f"{APP_NAME}:{key_name}".encode(),
|
||||
100000,
|
||||
dklen=32,
|
||||
)
|
||||
|
||||
|
||||
def _encrypt_value(value: str, key: bytes) -> str:
|
||||
"""Encrypt a value using XOR with the derived key (simple encryption)."""
|
||||
# For a more secure implementation, use cryptography library with AES
|
||||
# This is a simple XOR-based encryption suitable for the use case
|
||||
value_bytes = value.encode("utf-8")
|
||||
key_extended = (key * ((len(value_bytes) // len(key)) + 1))[: len(value_bytes)]
|
||||
encrypted = bytes(a ^ b for a, b in zip(value_bytes, key_extended))
|
||||
return base64.b64encode(encrypted).decode("ascii")
|
||||
|
||||
|
||||
def _decrypt_value(encrypted: str, key: bytes) -> str:
|
||||
"""Decrypt a value using XOR with the derived key."""
|
||||
encrypted_bytes = base64.b64decode(encrypted.encode("ascii"))
|
||||
key_extended = (key * ((len(encrypted_bytes) // len(key)) + 1))[: len(encrypted_bytes)]
|
||||
decrypted = bytes(a ^ b for a, b in zip(encrypted_bytes, key_extended))
|
||||
return decrypted.decode("utf-8")
|
||||
|
||||
|
||||
def _load_local_secrets() -> dict:
|
||||
"""Load the local secrets file."""
|
||||
secrets_file = _get_secrets_file_path()
|
||||
if os.path.exists(secrets_file):
|
||||
try:
|
||||
with open(secrets_file, "r") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as exc:
|
||||
logging.debug("Failed to load secrets file: %s", exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _save_local_secrets(secrets: dict) -> None:
|
||||
"""Save secrets to the local file."""
|
||||
secrets_file = _get_secrets_file_path()
|
||||
try:
|
||||
with open(secrets_file, "w") as f:
|
||||
json.dump(secrets, f)
|
||||
# Set restrictive permissions
|
||||
os.chmod(secrets_file, 0o600)
|
||||
except IOError as exc:
|
||||
logging.warning("Failed to save secrets file: %s", exc)
|
||||
|
||||
|
||||
def _get_portal_password(key: Key) -> str | None:
|
||||
"""Get a password using the XDG Desktop Portal Secret."""
|
||||
portal_secret = _get_portal_secret()
|
||||
if portal_secret is None:
|
||||
return None
|
||||
|
||||
secrets = _load_local_secrets()
|
||||
encrypted_value = secrets.get(key.value)
|
||||
if encrypted_value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
derived_key = _derive_key(portal_secret, key.value)
|
||||
return _decrypt_value(encrypted_value, derived_key)
|
||||
except Exception as exc:
|
||||
logging.debug("Failed to decrypt portal secret: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _set_portal_password(key: Key, password: str) -> bool:
|
||||
"""Set a password using the XDG Desktop Portal Secret."""
|
||||
portal_secret = _get_portal_secret()
|
||||
if portal_secret is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
derived_key = _derive_key(portal_secret, key.value)
|
||||
encrypted_value = _encrypt_value(password, derived_key)
|
||||
|
||||
secrets = _load_local_secrets()
|
||||
secrets[key.value] = encrypted_value
|
||||
_save_local_secrets(secrets)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logging.debug("Failed to set portal secret: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
def _delete_portal_password(key: Key) -> bool:
|
||||
"""Delete a password from the portal-based local storage."""
|
||||
secrets = _load_local_secrets()
|
||||
if key.value in secrets:
|
||||
del secrets[key.value]
|
||||
_save_local_secrets(secrets)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_password(key: Key) -> str | None:
|
||||
# On Linux, try XDG Desktop Portal first (works in sandboxed environments)
|
||||
if _is_linux():
|
||||
result = _get_portal_password(key)
|
||||
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Fall back to keyring (cross-platform, uses Secret Service on Linux)
|
||||
try:
|
||||
password = keyring.get_password(APP_NAME, username=key.value)
|
||||
if password is None:
|
||||
|
|
@ -219,25 +22,4 @@ def get_password(key: Key) -> str | None:
|
|||
|
||||
|
||||
def set_password(username: Key, password: str) -> None:
|
||||
# On Linux, try XDG Desktop Portal first (works in sandboxed environments)
|
||||
if _is_linux():
|
||||
if _set_portal_password(username, password):
|
||||
return
|
||||
|
||||
# Fall back to keyring (cross-platform, uses Secret Service on Linux)
|
||||
keyring.set_password(APP_NAME, username.value, password)
|
||||
|
||||
|
||||
def delete_password(key: Key) -> None:
|
||||
"""Delete a password from the secret store."""
|
||||
# On Linux, also delete from portal storage
|
||||
if _is_linux():
|
||||
_delete_portal_password(key)
|
||||
|
||||
# Delete from keyring
|
||||
try:
|
||||
keyring.delete_password(APP_NAME, key.value)
|
||||
except keyring.errors.PasswordDeleteError:
|
||||
pass # Password doesn't exist, ignore
|
||||
except Exception as exc:
|
||||
logging.warning("Unable to delete from keyring: %s", exc)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,15 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
import tempfile
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
|
||||
from yt_dlp import YoutubeDL
|
||||
|
||||
from buzz import whisper_audio
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
from buzz.whisper_audio import SAMPLE_RATE
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
get_output_file_path,
|
||||
|
|
@ -20,9 +17,6 @@ from buzz.transcriber.transcriber import (
|
|||
OutputFormat,
|
||||
)
|
||||
|
||||
app_env = os.environ.copy()
|
||||
app_env['PATH'] = os.pathsep.join([os.path.join(APP_BASE_DIR, "_internal")] + [app_env['PATH']])
|
||||
|
||||
|
||||
class FileTranscriber(QObject):
|
||||
transcription_task: FileTranscriptionTask
|
||||
|
|
@ -38,34 +32,10 @@ class FileTranscriber(QObject):
|
|||
@pyqtSlot()
|
||||
def run(self):
|
||||
if self.transcription_task.source == FileTranscriptionTask.Source.URL_IMPORT:
|
||||
cookiefile = os.getenv("BUZZ_DOWNLOAD_COOKIEFILE")
|
||||
|
||||
# First extract info to get the video title
|
||||
extract_options = {
|
||||
"logger": logging.getLogger(),
|
||||
}
|
||||
if cookiefile:
|
||||
extract_options["cookiefile"] = cookiefile
|
||||
|
||||
try:
|
||||
with YoutubeDL(extract_options) as ydl_info:
|
||||
info = ydl_info.extract_info(self.transcription_task.url, download=False)
|
||||
video_title = info.get("title", "audio")
|
||||
except Exception as exc:
|
||||
logging.debug(f"Error extracting video info: {exc}")
|
||||
video_title = "audio"
|
||||
|
||||
# Sanitize title for use as filename
|
||||
video_title = YoutubeDL.sanitize_info({"title": video_title})["title"]
|
||||
# Remove characters that are problematic in filenames
|
||||
for char in ['/', '\\', ':', '*', '?', '"', '<', '>', '|']:
|
||||
video_title = video_title.replace(char, '_')
|
||||
|
||||
# Create temp directory and use video title as filename
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
temp_output_path = os.path.join(temp_dir, video_title)
|
||||
temp_output_path = tempfile.mktemp()
|
||||
wav_file = temp_output_path + ".wav"
|
||||
wav_file = str(Path(wav_file).resolve())
|
||||
|
||||
cookiefile = os.getenv("BUZZ_DOWNLOAD_COOKIEFILE")
|
||||
|
||||
options = {
|
||||
"format": "bestaudio/best",
|
||||
|
|
@ -93,25 +63,12 @@ class FileTranscriber(QObject):
|
|||
"-threads", "0",
|
||||
"-i", temp_output_path,
|
||||
"-ac", "1",
|
||||
"-ar", str(whisper_audio.SAMPLE_RATE),
|
||||
"-ar", str(SAMPLE_RATE),
|
||||
"-acodec", "pcm_s16le",
|
||||
"-loglevel", "panic",
|
||||
wav_file
|
||||
]
|
||||
wav_file]
|
||||
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
|
||||
if len(result.stderr):
|
||||
logging.warning(f"Error processing downloaded audio. Error: {result.stderr.decode()}")
|
||||
|
|
@ -149,22 +106,13 @@ class FileTranscriber(QObject):
|
|||
)
|
||||
|
||||
if self.transcription_task.source == FileTranscriptionTask.Source.FOLDER_WATCH:
|
||||
# Use original_file_path if available (before speech extraction changed file_path)
|
||||
source_path = (
|
||||
self.transcription_task.original_file_path
|
||||
or self.transcription_task.file_path
|
||||
shutil.move(
|
||||
self.transcription_task.file_path,
|
||||
os.path.join(
|
||||
self.transcription_task.output_directory,
|
||||
os.path.basename(self.transcription_task.file_path),
|
||||
),
|
||||
)
|
||||
if source_path and os.path.exists(source_path):
|
||||
if self.transcription_task.delete_source_file:
|
||||
os.remove(source_path)
|
||||
else:
|
||||
shutil.move(
|
||||
source_path,
|
||||
os.path.join(
|
||||
self.transcription_task.output_directory,
|
||||
os.path.basename(source_path),
|
||||
),
|
||||
)
|
||||
|
||||
def on_download_progress(self, data: dict):
|
||||
if data["status"] == "downloading":
|
||||
|
|
@ -179,6 +127,7 @@ class FileTranscriber(QObject):
|
|||
...
|
||||
|
||||
|
||||
# TODO: Move to transcription service
|
||||
def write_output(
|
||||
path: str,
|
||||
segments: List[Segment],
|
||||
|
|
@ -192,15 +141,13 @@ def write_output(
|
|||
len(segments),
|
||||
)
|
||||
|
||||
with open(os.fsencode(path), "w", encoding="utf-8") as file:
|
||||
with open(path, "w", encoding="utf-8") as file:
|
||||
if output_format == OutputFormat.TXT:
|
||||
combined_text = ""
|
||||
previous_end_time = None
|
||||
|
||||
paragraph_split_time = int(os.getenv("BUZZ_PARAGRAPH_SPLIT_TIME", "2000"))
|
||||
|
||||
for segment in segments:
|
||||
if previous_end_time is not None and (segment.start - previous_end_time) >= paragraph_split_time:
|
||||
if previous_end_time is not None and (segment.start - previous_end_time) >= 2000:
|
||||
combined_text += "\n\n"
|
||||
combined_text += getattr(segment, segment_key).strip() + " "
|
||||
previous_end_time = segment.end
|
||||
|
|
@ -234,9 +181,3 @@ def to_timestamp(ms: float, ms_separator=".") -> str:
|
|||
sec = int(ms / 1000)
|
||||
ms = int(ms - sec * 1000)
|
||||
return f"{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}"
|
||||
|
||||
# To detect when transcription source is a video
|
||||
VIDEO_EXTENSIONS = {".mp4", ".mov", ".mkv", ".avi", ".m4v", ".webm", ".ogm", ".wmv"}
|
||||
|
||||
def is_video_file(path: str) -> bool:
|
||||
return Path(path).suffix.lower() in VIDEO_EXTENSIONS
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Optional, List
|
||||
|
||||
from PyQt6.QtCore import QObject
|
||||
from openai import OpenAI
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
from buzz.transcriber.openai_whisper_api_file_transcriber import OpenAIWhisperAPIFileTranscriber
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
|
||||
|
||||
|
||||
# Currently unused, but kept for future reference
|
||||
class LocalWhisperCppServerTranscriber(OpenAIWhisperAPIFileTranscriber):
|
||||
# To be used on Windows only
|
||||
def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None) -> None:
|
||||
super().__init__(task=task, parent=parent)
|
||||
|
||||
self.process = None
|
||||
self.initialization_error = None
|
||||
cmd = [
|
||||
os.path.join(APP_BASE_DIR, "whisper-server.exe"),
|
||||
"--port", "3000",
|
||||
"--inference-path", "/audio/transcriptions",
|
||||
"--threads", str(os.getenv("BUZZ_WHISPERCPP_N_THREADS", (os.cpu_count() or 8) // 2)),
|
||||
"--model", task.model_path,
|
||||
"--suppress-nst"
|
||||
]
|
||||
|
||||
if task.transcription_options.language is not None:
|
||||
cmd.extend(["--language", task.transcription_options.language])
|
||||
|
||||
logging.debug(f"Starting Whisper server with command: {' '.join(cmd)}")
|
||||
|
||||
self.process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL, # For debug set to subprocess.PIPE, but it will freeze on Windows after ~30 seconds
|
||||
stderr=subprocess.PIPE,
|
||||
shell=False,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
|
||||
# Wait for server to start and load model
|
||||
time.sleep(10)
|
||||
|
||||
if self.process is not None and self.process.poll() is None:
|
||||
logging.debug(f"Whisper server started successfully.")
|
||||
logging.debug(f"Model: {task.model_path}")
|
||||
else:
|
||||
stderr_output = ""
|
||||
if self.process.stderr is not None:
|
||||
stderr_output = self.process.stderr.read().decode()
|
||||
logging.error(f"Whisper server failed to start. Error: {stderr_output}")
|
||||
self.initialization_error = _("Whisper server failed to start. Check logs for details.")
|
||||
|
||||
if "ErrorOutOfDeviceMemory" in stderr_output:
|
||||
self.initialization_error = _("Whisper server failed to start due to insufficient memory. "
|
||||
"Please try again with a smaller model. "
|
||||
"To force CPU mode use BUZZ_FORCE_CPU=TRUE environment variable.")
|
||||
return
|
||||
|
||||
self.openai_client = OpenAI(
|
||||
api_key="not-used",
|
||||
base_url="http://127.0.0.1:3000",
|
||||
max_retries=0
|
||||
)
|
||||
|
||||
def transcribe(self) -> List[Segment]:
|
||||
if self.initialization_error:
|
||||
raise Exception(self.initialization_error)
|
||||
|
||||
return super().transcribe()
|
||||
|
||||
def stop(self):
|
||||
if self.process and self.process.poll() is None:
|
||||
try:
|
||||
self.process.terminate()
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
# Force kill if terminate doesn't work within 5 seconds
|
||||
logging.warning("Whisper server didn't terminate gracefully, force killing")
|
||||
self.process.kill()
|
||||
try:
|
||||
self.process.wait(timeout=2)
|
||||
except subprocess.TimeoutExpired:
|
||||
logging.error("Failed to kill whisper server process")
|
||||
except Exception as e:
|
||||
logging.error(f"Error stopping whisper server: {e}")
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
|
|
@ -1,40 +1,19 @@
|
|||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from PyQt6.QtCore import QObject
|
||||
from openai import OpenAI
|
||||
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber, app_env
|
||||
from buzz.model_loader import get_custom_api_whisper_model
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment, Task
|
||||
|
||||
|
||||
def append_segment(result, txt: bytes, start: int, end: int):
|
||||
if txt == b'':
|
||||
return True
|
||||
|
||||
# try-catch will guard against multi-byte utf-8 characters
|
||||
# https://github.com/ggerganov/whisper.cpp/issues/1798
|
||||
try:
|
||||
result.append(
|
||||
Segment(
|
||||
start=start * 10, # centisecond to ms
|
||||
end=end * 10, # centisecond to ms
|
||||
text=txt.decode("utf-8"),
|
||||
)
|
||||
)
|
||||
|
||||
return True
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
|
||||
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
||||
def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None):
|
||||
super().__init__(task=task, parent=parent)
|
||||
|
|
@ -45,13 +24,9 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
self.task = task.transcription_options.task
|
||||
self.openai_client = OpenAI(
|
||||
api_key=self.transcription_task.transcription_options.openai_access_token,
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None,
|
||||
max_retries=0
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None
|
||||
)
|
||||
self.whisper_api_model = settings.value(
|
||||
key=Settings.Key.OPENAI_API_MODEL, default_value="whisper-1"
|
||||
)
|
||||
self.word_level_timings = self.transcription_task.transcription_options.word_level_timings
|
||||
self.whisper_api_model = get_custom_api_whisper_model(custom_openai_base_url)
|
||||
logging.debug("Will use whisper API on %s, %s",
|
||||
custom_openai_base_url, self.whisper_api_model)
|
||||
|
||||
|
|
@ -63,7 +38,6 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
)
|
||||
|
||||
mp3_file = tempfile.mktemp() + ".mp3"
|
||||
mp3_file = str(Path(mp3_file).resolve())
|
||||
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
|
|
@ -72,19 +46,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
"-i", self.transcription_task.file_path, mp3_file
|
||||
]
|
||||
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags = subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
logging.warning(f"FFMPEG audio load warning. Process return code was not zero: {result.returncode}")
|
||||
|
|
@ -101,27 +63,10 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||
mp3_file,
|
||||
]
|
||||
|
||||
# fmt: on
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
|
||||
duration_secs = float(
|
||||
subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
).stdout.decode("utf-8"),
|
||||
)
|
||||
else:
|
||||
duration_secs = float(
|
||||
subprocess.run(cmd, capture_output=True, check=True).stdout.decode("utf-8")
|
||||
)
|
||||
duration_secs = float(
|
||||
subprocess.run(cmd, capture_output=True, check=True).stdout.decode("utf-8")
|
||||
)
|
||||
|
||||
total_size = os.path.getsize(mp3_file)
|
||||
max_chunk_size = 25 * 1024 * 1024
|
||||
|
|
@ -143,7 +88,6 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
chunk_end = min((i + 1) * chunk_duration, duration_secs)
|
||||
|
||||
chunk_file = tempfile.mktemp() + ".mp3"
|
||||
chunk_file = str(Path(chunk_file).resolve())
|
||||
|
||||
# fmt: off
|
||||
cmd = [
|
||||
|
|
@ -155,21 +99,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
chunk_file,
|
||||
]
|
||||
# fmt: on
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
subprocess.run(cmd, capture_output=True, check=True)
|
||||
|
||||
subprocess.run(cmd, capture_output=True, check=True)
|
||||
logging.debug('Created chunk file "%s"', chunk_file)
|
||||
|
||||
segments.extend(
|
||||
|
|
@ -182,29 +112,14 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
|
||||
return segments
|
||||
|
||||
@staticmethod
|
||||
def get_value(segment, key, default=None):
|
||||
if hasattr(segment, key):
|
||||
return getattr(segment, key)
|
||||
if isinstance(segment, dict):
|
||||
return segment.get(key, default)
|
||||
return default
|
||||
|
||||
def get_segments_for_file(self, file: str, offset_ms: int = 0):
|
||||
with open(file, "rb") as file:
|
||||
# gpt-4o models don't support verbose_json format
|
||||
response_format = "json" if self.whisper_api_model.startswith("gpt-4o") else "verbose_json"
|
||||
|
||||
options = {
|
||||
"model": self.whisper_api_model,
|
||||
"file": file,
|
||||
"response_format": response_format,
|
||||
"response_format": "verbose_json",
|
||||
"prompt": self.transcription_task.transcription_options.initial_prompt,
|
||||
}
|
||||
|
||||
if self.word_level_timings:
|
||||
options["timestamp_granularities"] = ["word"]
|
||||
|
||||
transcript = (
|
||||
self.openai_client.audio.transcriptions.create(
|
||||
**options,
|
||||
|
|
@ -214,80 +129,14 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
else self.openai_client.audio.translations.create(**options)
|
||||
)
|
||||
|
||||
segments = getattr(transcript, "segments", None)
|
||||
|
||||
words = getattr(transcript, "words", None)
|
||||
if words is None and "words" in transcript.model_extra:
|
||||
words = transcript.model_extra["words"]
|
||||
|
||||
if segments is None:
|
||||
if "segments" in transcript.model_extra:
|
||||
segments = transcript.model_extra["segments"]
|
||||
else:
|
||||
# gpt-4o models return only text without segments/timestamps
|
||||
segments = [{"text": transcript.text, "start": 0, "end": 0, "words": words}]
|
||||
|
||||
result_segments = []
|
||||
if self.word_level_timings:
|
||||
|
||||
# Detect response from whisper.cpp API
|
||||
first_segment = segments[0] if segments else None
|
||||
is_whisper_cpp = (first_segment and hasattr(first_segment, "tokens")
|
||||
and hasattr(first_segment, "avg_logprob") and hasattr(first_segment, "no_speech_prob"))
|
||||
|
||||
if is_whisper_cpp:
|
||||
txt_buffer = b''
|
||||
txt_start = 0
|
||||
txt_end = 0
|
||||
|
||||
for segment in segments:
|
||||
for word in self.get_value(segment, "words"):
|
||||
|
||||
txt = self.get_value(word, "word").encode("utf-8")
|
||||
start = self.get_value(word, "start")
|
||||
end = self.get_value(word, "end")
|
||||
|
||||
if txt.startswith(b' ') and append_segment(result_segments, txt_buffer, txt_start, txt_end):
|
||||
txt_buffer = txt
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
if txt.startswith(b', '):
|
||||
txt_buffer += b','
|
||||
append_segment(result_segments, txt_buffer, txt_start, txt_end)
|
||||
txt_buffer = txt.lstrip(b',')
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
txt_buffer += txt
|
||||
txt_end = end
|
||||
|
||||
# Append the last segment
|
||||
append_segment(result_segments, txt_buffer, txt_start, txt_end)
|
||||
|
||||
else:
|
||||
for segment in segments:
|
||||
for word in self.get_value(segment, "words"):
|
||||
result_segments.append(
|
||||
Segment(
|
||||
int(self.get_value(word, "start") * 1000 + offset_ms),
|
||||
int(self.get_value(word, "end") * 1000 + offset_ms),
|
||||
self.get_value(word, "word"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
result_segments = [
|
||||
Segment(
|
||||
int(self.get_value(segment, "start", 0) * 1000 + offset_ms),
|
||||
int(self.get_value(segment, "end", 0) * 1000 + offset_ms),
|
||||
self.get_value(segment, "text", ""),
|
||||
)
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
return result_segments
|
||||
return [
|
||||
Segment(
|
||||
int(segment["start"] * 1000 + offset_ms),
|
||||
int(segment["end"] * 1000 + offset_ms),
|
||||
segment["text"],
|
||||
)
|
||||
for segment in transcript.model_extra["segments"]
|
||||
]
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -2,18 +2,13 @@ import datetime
|
|||
import logging
|
||||
import platform
|
||||
import os
|
||||
import sys
|
||||
import wave
|
||||
import time
|
||||
import tempfile
|
||||
import threading
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
# Preload CUDA libraries before importing torch
|
||||
from buzz import cuda_setup # noqa: F401
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import sounddevice
|
||||
|
|
@ -22,12 +17,11 @@ from openai import OpenAI
|
|||
from PyQt6.QtCore import QObject, pyqtSignal
|
||||
|
||||
from buzz import whisper_audio
|
||||
from buzz.locale import _
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
from buzz.model_loader import ModelType, map_language_to_mms
|
||||
from buzz.model_loader import WhisperModelSize, ModelType, get_custom_api_whisper_model
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber.transcriber import TranscriptionOptions, Task, DEFAULT_WHISPER_TEMPERATURE
|
||||
from buzz.transformers_whisper import TransformersTranscriber
|
||||
from buzz.transcriber.transcriber import TranscriptionOptions, Task
|
||||
from buzz.transcriber.whisper_cpp import WhisperCpp
|
||||
from buzz.transformers_whisper import TransformersWhisper
|
||||
from buzz.settings.recording_transcriber_mode import RecordingTranscriberMode
|
||||
|
||||
import whisper
|
||||
|
|
@ -38,9 +32,6 @@ class RecordingTranscriber(QObject):
|
|||
transcription = pyqtSignal(str)
|
||||
finished = pyqtSignal()
|
||||
error = pyqtSignal(str)
|
||||
amplitude_changed = pyqtSignal(float)
|
||||
average_amplitude_changed = pyqtSignal(float)
|
||||
queue_size_changed = pyqtSignal(int)
|
||||
is_running = False
|
||||
SAMPLE_RATE = whisper_audio.SAMPLE_RATE
|
||||
|
||||
|
|
@ -62,10 +53,10 @@ class RecordingTranscriber(QObject):
|
|||
self.input_device_index = input_device_index
|
||||
self.sample_rate = sample_rate if sample_rate is not None else whisper_audio.SAMPLE_RATE
|
||||
self.model_path = model_path
|
||||
self.n_batch_samples = int(5 * self.sample_rate) # 5 seconds
|
||||
self.n_batch_samples = 5 * self.sample_rate # 5 seconds
|
||||
self.keep_sample_seconds = 0.15
|
||||
if self.transcriber_mode == RecordingTranscriberMode.APPEND_AND_CORRECT:
|
||||
self.n_batch_samples = int(transcription_options.transcription_step * self.sample_rate)
|
||||
self.n_batch_samples = 3 * self.sample_rate # 3 seconds
|
||||
self.keep_sample_seconds = 1.5
|
||||
# pause queueing if more than 3 batches behind
|
||||
self.max_queue_size = 3 * self.n_batch_samples
|
||||
|
|
@ -73,80 +64,58 @@ class RecordingTranscriber(QObject):
|
|||
self.mutex = threading.Lock()
|
||||
self.sounddevice = sounddevice
|
||||
self.openai_client = None
|
||||
self.whisper_api_model = self.settings.value(
|
||||
key=Settings.Key.OPENAI_API_MODEL, default_value="whisper-1"
|
||||
)
|
||||
self.process = None
|
||||
self._stderr_lines: list[bytes] = []
|
||||
self.whisper_api_model = get_custom_api_whisper_model("")
|
||||
|
||||
def start(self):
|
||||
self.is_running = True
|
||||
model = None
|
||||
model_path = self.model_path
|
||||
keep_samples = int(self.keep_sample_seconds * self.sample_rate)
|
||||
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
use_cuda = torch.cuda.is_available() and force_cpu == "false"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
logging.debug(f"CUDA version detected: {torch.version.cuda}")
|
||||
|
||||
if self.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
device = "cuda" if use_cuda else "cpu"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = whisper.load_model(model_path, device=device)
|
||||
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
self.start_local_whisper_server()
|
||||
if self.openai_client is None:
|
||||
if not self.is_running:
|
||||
self.finished.emit()
|
||||
else:
|
||||
self.error.emit(_("Whisper server failed to start. Check logs for details."))
|
||||
return
|
||||
model = WhisperCpp(model_path)
|
||||
elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
|
||||
model_root_dir = user_cache_dir("Buzz")
|
||||
model_root_dir = os.path.join(model_root_dir, "models")
|
||||
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
|
||||
|
||||
device = "auto"
|
||||
if platform.system() == "Windows":
|
||||
logging.debug("CUDA GPUs are currently no supported on Running on Windows, using CPU")
|
||||
device = "cpu"
|
||||
|
||||
if torch.cuda.is_available() and torch.version.cuda < "12":
|
||||
logging.debug("Unsupported CUDA version (<12), using CPU")
|
||||
device = "cpu"
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
logging.debug("CUDA is not available, using CPU")
|
||||
device = "cpu"
|
||||
|
||||
if force_cpu != "false":
|
||||
device = "cpu"
|
||||
|
||||
# Check if user wants reduced GPU memory usage (int8 quantization)
|
||||
reduce_gpu_memory = os.getenv("BUZZ_REDUCE_GPU_MEMORY", "false") != "false"
|
||||
compute_type = "default"
|
||||
if reduce_gpu_memory:
|
||||
compute_type = "int8" if device == "cpu" else "int8_float16"
|
||||
logging.debug(f"Using {compute_type} compute type for reduced memory usage")
|
||||
|
||||
model = faster_whisper.WhisperModel(
|
||||
model_size_or_path=model_path,
|
||||
download_root=model_root_dir,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
cpu_threads=(os.cpu_count() or 8)//2,
|
||||
)
|
||||
|
||||
# Fix for large-v3 https://github.com/guillaumekln/faster-whisper/issues/547#issuecomment-1797962599
|
||||
if self.transcription_options.model.whisper_model_size == WhisperModelSize.LARGEV3:
|
||||
model.feature_extractor.mel_filters = model.feature_extractor.get_mel_filters(
|
||||
model.feature_extractor.sampling_rate, model.feature_extractor.n_fft, n_mels=128
|
||||
)
|
||||
elif self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
custom_openai_base_url = self.settings.value(
|
||||
key=Settings.Key.CUSTOM_OPENAI_BASE_URL, default_value=""
|
||||
)
|
||||
self.whisper_api_model = get_custom_api_whisper_model(custom_openai_base_url)
|
||||
self.openai_client = OpenAI(
|
||||
api_key=self.transcription_options.openai_access_token,
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None,
|
||||
max_retries=0
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None
|
||||
)
|
||||
logging.debug("Will use whisper API on %s, %s",
|
||||
custom_openai_base_url, self.whisper_api_model)
|
||||
else: # ModelType.HUGGING_FACE
|
||||
model = TransformersTranscriber(model_path)
|
||||
model = TransformersWhisper(model_path)
|
||||
|
||||
initial_prompt = self.transcription_options.initial_prompt
|
||||
|
||||
|
|
@ -158,6 +127,7 @@ class RecordingTranscriber(QObject):
|
|||
self.input_device_index,
|
||||
)
|
||||
|
||||
self.is_running = True
|
||||
try:
|
||||
with self.sounddevice.InputStream(
|
||||
samplerate=self.sample_rate,
|
||||
|
|
@ -169,33 +139,20 @@ class RecordingTranscriber(QObject):
|
|||
while self.is_running:
|
||||
if self.queue.size >= self.n_batch_samples:
|
||||
self.mutex.acquire()
|
||||
cut = self.find_silence_cut_point(
|
||||
self.queue[:self.n_batch_samples], self.sample_rate
|
||||
)
|
||||
samples = self.queue[:cut]
|
||||
if self.transcriber_mode == RecordingTranscriberMode.APPEND_AND_CORRECT:
|
||||
self.queue = self.queue[cut - keep_samples:]
|
||||
else:
|
||||
self.queue = self.queue[cut:]
|
||||
samples = self.queue[: self.n_batch_samples]
|
||||
self.queue = self.queue[self.n_batch_samples - keep_samples:]
|
||||
self.mutex.release()
|
||||
|
||||
amplitude = self.amplitude(samples)
|
||||
self.average_amplitude_changed.emit(amplitude)
|
||||
self.queue_size_changed.emit(self.queue.size)
|
||||
|
||||
logging.debug(
|
||||
"Processing next frame, sample size = %s, queue size = %s, amplitude = %s",
|
||||
samples.size,
|
||||
self.queue.size,
|
||||
amplitude,
|
||||
self.amplitude(samples),
|
||||
)
|
||||
|
||||
if amplitude < self.transcription_options.silence_threshold:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
time_started = datetime.datetime.now()
|
||||
|
||||
# TODO Filter out silent audio
|
||||
|
||||
if (
|
||||
self.transcription_options.model.model_type
|
||||
== ModelType.WHISPER
|
||||
|
|
@ -206,9 +163,18 @@ class RecordingTranscriber(QObject):
|
|||
language=self.transcription_options.language,
|
||||
task=self.transcription_options.task.value,
|
||||
initial_prompt=initial_prompt,
|
||||
temperature=DEFAULT_WHISPER_TEMPERATURE,
|
||||
no_speech_threshold=0.4,
|
||||
fp16=False,
|
||||
temperature=self.transcription_options.temperature,
|
||||
)
|
||||
elif (
|
||||
self.transcription_options.model.model_type
|
||||
== ModelType.WHISPER_CPP
|
||||
):
|
||||
assert isinstance(model, WhisperCpp)
|
||||
result = model.transcribe(
|
||||
audio=samples,
|
||||
params=model.get_params(
|
||||
transcription_options=self.transcription_options
|
||||
),
|
||||
)
|
||||
elif (
|
||||
self.transcription_options.model.model_type
|
||||
|
|
@ -221,43 +187,25 @@ class RecordingTranscriber(QObject):
|
|||
if self.transcription_options.language != ""
|
||||
else None,
|
||||
task=self.transcription_options.task.value,
|
||||
# Prevent crash on Windows https://github.com/SYSTRAN/faster-whisper/issues/71#issuecomment-1526263764
|
||||
temperature=0 if platform.system() == "Windows" else DEFAULT_WHISPER_TEMPERATURE,
|
||||
temperature=self.transcription_options.temperature,
|
||||
initial_prompt=self.transcription_options.initial_prompt,
|
||||
word_timestamps=False,
|
||||
without_timestamps=True,
|
||||
no_speech_threshold=0.4,
|
||||
word_timestamps=self.transcription_options.word_level_timings,
|
||||
)
|
||||
result = {"text": " ".join([segment.text for segment in whisper_segments])}
|
||||
elif (
|
||||
self.transcription_options.model.model_type
|
||||
== ModelType.HUGGING_FACE
|
||||
):
|
||||
assert isinstance(model, TransformersTranscriber)
|
||||
# Handle MMS-specific language and task
|
||||
if model.is_mms_model:
|
||||
language = map_language_to_mms(
|
||||
self.transcription_options.language or "eng"
|
||||
)
|
||||
effective_task = Task.TRANSCRIBE.value
|
||||
else:
|
||||
language = (
|
||||
self.transcription_options.language
|
||||
if self.transcription_options.language is not None
|
||||
else "en"
|
||||
)
|
||||
effective_task = self.transcription_options.task.value
|
||||
|
||||
assert isinstance(model, TransformersWhisper)
|
||||
result = model.transcribe(
|
||||
audio=samples,
|
||||
language=language,
|
||||
task=effective_task,
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None
|
||||
else "en",
|
||||
task=self.transcription_options.task.value,
|
||||
)
|
||||
else: # OPEN_AI_WHISPER_API, also used for WHISPER_CPP
|
||||
if self.openai_client is None:
|
||||
self.error.emit(_("A connection error occurred"))
|
||||
return
|
||||
|
||||
else: # OPEN_AI_WHISPER_API
|
||||
assert self.openai_client is not None
|
||||
# scale samples to 16-bit PCM
|
||||
pcm_data = (samples * 32767).astype(np.int16).tobytes()
|
||||
|
||||
|
|
@ -274,7 +222,7 @@ class RecordingTranscriber(QObject):
|
|||
options = {
|
||||
"model": self.whisper_api_model,
|
||||
"file": temp_file,
|
||||
"response_format": "json",
|
||||
"response_format": "verbose_json",
|
||||
"prompt": self.transcription_options.initial_prompt,
|
||||
}
|
||||
|
||||
|
|
@ -288,24 +236,17 @@ class RecordingTranscriber(QObject):
|
|||
else self.openai_client.audio.translations.create(**options)
|
||||
)
|
||||
|
||||
if "segments" in transcript.model_extra:
|
||||
result = {"text": " ".join(
|
||||
[segment["text"] for segment in transcript.model_extra["segments"]])}
|
||||
else:
|
||||
result = {"text": transcript.text}
|
||||
|
||||
result = {"text": " ".join(
|
||||
[segment["text"] for segment in transcript.model_extra["segments"]])}
|
||||
except Exception as e:
|
||||
if self.is_running:
|
||||
result = {"text": f"Error: {str(e)}"}
|
||||
else:
|
||||
result = {"text": ""}
|
||||
result = {"text": f"Error: {str(e)}"}
|
||||
|
||||
os.unlink(temp_filename)
|
||||
|
||||
next_text: str = result.get("text")
|
||||
|
||||
# Update initial prompt between successive recording chunks
|
||||
initial_prompt = next_text
|
||||
initial_prompt += next_text
|
||||
|
||||
logging.debug(
|
||||
"Received next result, length = %s, time taken = %s",
|
||||
|
|
@ -318,19 +259,8 @@ class RecordingTranscriber(QObject):
|
|||
|
||||
except PortAudioError as exc:
|
||||
self.error.emit(str(exc))
|
||||
logging.exception("PortAudio error during recording")
|
||||
logging.exception("")
|
||||
return
|
||||
except Exception as exc:
|
||||
logging.exception("Unexpected error during recording")
|
||||
self.error.emit(str(exc))
|
||||
return
|
||||
|
||||
# Cleanup before emitting finished to avoid destroying QThread
|
||||
# while this function is still on the call stack
|
||||
if model:
|
||||
del model
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.finished.emit()
|
||||
|
||||
|
|
@ -353,172 +283,13 @@ class RecordingTranscriber(QObject):
|
|||
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
|
||||
# Try to enqueue the next block. If the queue is already full, drop the block.
|
||||
chunk: np.ndarray = in_data.ravel()
|
||||
|
||||
amplitude = self.amplitude(chunk)
|
||||
self.amplitude_changed.emit(amplitude)
|
||||
|
||||
with self.mutex:
|
||||
if self.queue.size < self.max_queue_size:
|
||||
self.queue = np.append(self.queue, chunk)
|
||||
|
||||
@staticmethod
|
||||
def find_silence_cut_point(samples: np.ndarray, sample_rate: int,
|
||||
search_seconds: float = 1.5,
|
||||
window_seconds: float = 0.02,
|
||||
silence_ratio: float = 0.5) -> int:
|
||||
"""Return index of the last quiet point in the final search_seconds of samples.
|
||||
|
||||
Scans backwards through short windows; returns the midpoint of the rightmost
|
||||
window whose RMS is below silence_ratio * mean_rms of the search region.
|
||||
Falls back to len(samples) if no quiet window is found.
|
||||
"""
|
||||
window = int(window_seconds * sample_rate)
|
||||
search_start = max(0, len(samples) - int(search_seconds * sample_rate))
|
||||
region = samples[search_start:]
|
||||
n_windows = (len(region) - window) // window
|
||||
if n_windows < 1:
|
||||
return len(samples)
|
||||
|
||||
energies = np.array([
|
||||
np.sqrt(np.mean(region[i * window:(i + 1) * window] ** 2))
|
||||
for i in range(n_windows)
|
||||
])
|
||||
mean_energy = energies.mean()
|
||||
threshold = silence_ratio * mean_energy
|
||||
|
||||
for i in range(n_windows - 1, -1, -1):
|
||||
if energies[i] < threshold:
|
||||
cut = search_start + i * window + window // 2
|
||||
return cut
|
||||
|
||||
return len(samples)
|
||||
|
||||
@staticmethod
|
||||
def amplitude(arr: np.ndarray):
|
||||
return float(np.sqrt(np.mean(arr**2)))
|
||||
|
||||
def _drain_stderr(self):
|
||||
if self.process and self.process.stderr:
|
||||
for line in self.process.stderr:
|
||||
self._stderr_lines.append(line)
|
||||
return (abs(max(arr)) + abs(min(arr))) / 2
|
||||
|
||||
def stop_recording(self):
|
||||
self.is_running = False
|
||||
if self.process and self.process.poll() is None:
|
||||
self.process.terminate()
|
||||
try:
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
logging.warning("Whisper server process had to be killed after timeout")
|
||||
|
||||
def start_local_whisper_server(self):
|
||||
# Reduce verbose HTTP client logging from OpenAI/httpx
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
logging.getLogger("openai").setLevel(logging.WARNING)
|
||||
|
||||
self.transcription.emit(_("Starting Whisper.cpp..."))
|
||||
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
self.transcription.emit(_("First time use of a model may take up to several minutest to load."))
|
||||
|
||||
self.process = None
|
||||
|
||||
server_executable = "whisper-server.exe" if sys.platform == "win32" else "whisper-server"
|
||||
server_path = os.path.join(APP_BASE_DIR, "whisper_cpp", server_executable)
|
||||
|
||||
# If running Mac and Windows installed version
|
||||
if not os.path.exists(server_path):
|
||||
server_path = os.path.join(APP_BASE_DIR, "buzz", "whisper_cpp", server_executable)
|
||||
|
||||
cmd = [
|
||||
server_path,
|
||||
"--port", "3003",
|
||||
"--inference-path", "/audio/transcriptions",
|
||||
"--threads", str(os.getenv("BUZZ_WHISPERCPP_N_THREADS", (os.cpu_count() or 8) // 2)),
|
||||
"--model", self.model_path,
|
||||
"--no-timestamps",
|
||||
# Protections against hallucinated repetition. Seems to be problem on macOS
|
||||
# https://github.com/ggml-org/whisper.cpp/issues/1507
|
||||
"--max-context", "64",
|
||||
"--entropy-thold", "2.8",
|
||||
"--suppress-nst"
|
||||
]
|
||||
|
||||
if self.transcription_options.language is not None:
|
||||
cmd.extend(["--language", self.transcription_options.language])
|
||||
else:
|
||||
cmd.extend(["--language", "auto"])
|
||||
|
||||
logging.debug(f"Starting Whisper server with command: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
self.process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=False,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
self.process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=False,
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to start whisper-server subprocess: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
return
|
||||
|
||||
# Drain stderr in a background thread to prevent pipe buffer from filling
|
||||
# up and blocking the subprocess (especially on Windows with compiled exe).
|
||||
self._stderr_lines = []
|
||||
stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True)
|
||||
stderr_thread.start()
|
||||
|
||||
# Wait for server to start and load model, checking periodically
|
||||
for i in range(100): # 10 seconds total, in 0.1s increments
|
||||
if not self.is_running or self.process.poll() is not None:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
if self.process is not None and self.process.poll() is None:
|
||||
self.transcription.emit(_("Starting transcription..."))
|
||||
logging.debug(f"Whisper server started successfully.")
|
||||
logging.debug(f"Model: {self.model_path}")
|
||||
else:
|
||||
stderr_thread.join(timeout=2)
|
||||
stderr_output = b"".join(self._stderr_lines).decode(errors="replace")
|
||||
logging.error(f"Whisper server failed to start. Error: {stderr_output}")
|
||||
|
||||
self.transcription.emit(_("Whisper server failed to start. Check logs for details."))
|
||||
|
||||
if "ErrorOutOfDeviceMemory" in stderr_output:
|
||||
message = _(
|
||||
"Whisper server failed to start due to insufficient memory. "
|
||||
"Please try again with a smaller model. "
|
||||
"To force CPU mode use BUZZ_FORCE_CPU=TRUE environment variable."
|
||||
)
|
||||
logging.error(message)
|
||||
self.transcription.emit(message)
|
||||
|
||||
return
|
||||
|
||||
self.openai_client = OpenAI(
|
||||
api_key="not-used",
|
||||
base_url="http://127.0.0.1:3003",
|
||||
timeout=30.0,
|
||||
max_retries=0
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if self.process and self.process.poll() is None:
|
||||
self.process.terminate()
|
||||
try:
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
|
|
@ -21,7 +21,7 @@ class Task(enum.Enum):
|
|||
|
||||
|
||||
TASK_LABEL_TRANSLATIONS = {
|
||||
Task.TRANSLATE: _("Translate to English"),
|
||||
Task.TRANSLATE: _("Translate"),
|
||||
Task.TRANSCRIBE: _("Transcribe"),
|
||||
}
|
||||
|
||||
|
|
@ -35,106 +35,106 @@ class Segment:
|
|||
|
||||
|
||||
LANGUAGES = {
|
||||
"en": _("English"),
|
||||
"zh": _("Chinese"),
|
||||
"de": _("German"),
|
||||
"es": _("Spanish"),
|
||||
"ru": _("Russian"),
|
||||
"ko": _("Korean"),
|
||||
"fr": _("French"),
|
||||
"ja": _("Japanese"),
|
||||
"pt": _("Portuguese"),
|
||||
"tr": _("Turkish"),
|
||||
"pl": _("Polish"),
|
||||
"ca": _("Catalan"),
|
||||
"nl": _("Dutch"),
|
||||
"ar": _("Arabic"),
|
||||
"sv": _("Swedish"),
|
||||
"it": _("Italian"),
|
||||
"id": _("Indonesian"),
|
||||
"hi": _("Hindi"),
|
||||
"fi": _("Finnish"),
|
||||
"vi": _("Vietnamese"),
|
||||
"he": _("Hebrew"),
|
||||
"uk": _("Ukrainian"),
|
||||
"el": _("Greek"),
|
||||
"ms": _("Malay"),
|
||||
"cs": _("Czech"),
|
||||
"ro": _("Romanian"),
|
||||
"da": _("Danish"),
|
||||
"hu": _("Hungarian"),
|
||||
"ta": _("Tamil"),
|
||||
"no": _("Norwegian"),
|
||||
"th": _("Thai"),
|
||||
"ur": _("Urdu"),
|
||||
"hr": _("Croatian"),
|
||||
"bg": _("Bulgarian"),
|
||||
"lt": _("Lithuanian"),
|
||||
"la": _("Latin"),
|
||||
"mi": _("Maori"),
|
||||
"ml": _("Malayalam"),
|
||||
"cy": _("Welsh"),
|
||||
"sk": _("Slovak"),
|
||||
"te": _("Telugu"),
|
||||
"fa": _("Persian"),
|
||||
"lv": _("Latvian"),
|
||||
"bn": _("Bengali"),
|
||||
"sr": _("Serbian"),
|
||||
"az": _("Azerbaijani"),
|
||||
"sl": _("Slovenian"),
|
||||
"kn": _("Kannada"),
|
||||
"et": _("Estonian"),
|
||||
"mk": _("Macedonian"),
|
||||
"br": _("Breton"),
|
||||
"eu": _("Basque"),
|
||||
"is": _("Icelandic"),
|
||||
"hy": _("Armenian"),
|
||||
"ne": _("Nepali"),
|
||||
"mn": _("Mongolian"),
|
||||
"bs": _("Bosnian"),
|
||||
"kk": _("Kazakh"),
|
||||
"sq": _("Albanian"),
|
||||
"sw": _("Swahili"),
|
||||
"gl": _("Galician"),
|
||||
"mr": _("Marathi"),
|
||||
"pa": _("Punjabi"),
|
||||
"si": _("Sinhala"),
|
||||
"km": _("Khmer"),
|
||||
"sn": _("Shona"),
|
||||
"yo": _("Yoruba"),
|
||||
"so": _("Somali"),
|
||||
"af": _("Afrikaans"),
|
||||
"oc": _("Occitan"),
|
||||
"ka": _("Georgian"),
|
||||
"be": _("Belarusian"),
|
||||
"tg": _("Tajik"),
|
||||
"sd": _("Sindhi"),
|
||||
"gu": _("Gujarati"),
|
||||
"am": _("Amharic"),
|
||||
"yi": _("Yiddish"),
|
||||
"lo": _("Lao"),
|
||||
"uz": _("Uzbek"),
|
||||
"fo": _("Faroese"),
|
||||
"ht": _("Haitian Creole"),
|
||||
"ps": _("Pashto"),
|
||||
"tk": _("Turkmen"),
|
||||
"nn": _("Nynorsk"),
|
||||
"mt": _("Maltese"),
|
||||
"sa": _("Sanskrit"),
|
||||
"lb": _("Luxembourgish"),
|
||||
"my": _("Myanmar"),
|
||||
"bo": _("Tibetan"),
|
||||
"tl": _("Tagalog"),
|
||||
"mg": _("Malagasy"),
|
||||
"as": _("Assamese"),
|
||||
"tt": _("Tatar"),
|
||||
"haw": _("Hawaiian"),
|
||||
"ln": _("Lingala"),
|
||||
"ha": _("Hausa"),
|
||||
"ba": _("Bashkir"),
|
||||
"jw": _("Javanese"),
|
||||
"su": _("Sundanese"),
|
||||
"yue": _("Cantonese"),
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -144,7 +144,6 @@ class TranscriptionOptions:
|
|||
task: Task = Task.TRANSCRIBE
|
||||
model: TranscriptionModel = field(default_factory=TranscriptionModel)
|
||||
word_level_timings: bool = False
|
||||
extract_speech: bool = False
|
||||
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
|
||||
initial_prompt: str = ""
|
||||
openai_access_token: str = field(
|
||||
|
|
@ -153,9 +152,6 @@ class TranscriptionOptions:
|
|||
enable_llm_translation: bool = False
|
||||
llm_prompt: str = ""
|
||||
llm_model: str = ""
|
||||
silence_threshold: float = 0.0025
|
||||
line_separator: str = "\n\n"
|
||||
transcription_step: float = 3.5
|
||||
|
||||
|
||||
def humanize_language(language: str) -> str:
|
||||
|
|
@ -202,8 +198,6 @@ class FileTranscriptionTask:
|
|||
output_directory: Optional[str] = None
|
||||
source: Source = Source.FILE_IMPORT
|
||||
file_path: Optional[str] = None
|
||||
original_file_path: Optional[str] = None # Original path before speech extraction
|
||||
delete_source_file: bool = False
|
||||
url: Optional[str] = None
|
||||
fraction_downloaded: float = 0.0
|
||||
|
||||
|
|
@ -218,10 +212,8 @@ class Stopped(Exception):
|
|||
pass
|
||||
|
||||
|
||||
SUPPORTED_AUDIO_FORMATS = "Media files (*.mp3 *.wav *.m4a *.ogg *.opus *.flac *.mp4 *.webm *.ogm *.mov *.mkv *.avi *.wmv);;\
|
||||
Audio files (*.mp3 *.wav *.m4a *.ogg *.opus *.flac);;\
|
||||
Video files (*.mp4 *.webm *.ogm *.mov *.mkv *.avi *.wmv);;\
|
||||
All files (*.*)"
|
||||
SUPPORTED_AUDIO_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg *.opus *.flac);;\
|
||||
Video files (*.mp4 *.webm *.ogm *.mov *.mkv *.avi *.wmv);;All files (*.*)"
|
||||
|
||||
|
||||
def get_output_file_path(
|
||||
|
|
@ -234,9 +226,6 @@ def get_output_file_path(
|
|||
export_file_name_template: str | None = None,
|
||||
):
|
||||
input_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
# Remove "_speech" suffix from extracted speech files
|
||||
if input_file_name.endswith("_speech"):
|
||||
input_file_name = input_file_name[:-7]
|
||||
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
|
||||
|
||||
export_file_name_template = (
|
||||
|
|
|
|||
|
|
@ -1,383 +1,253 @@
|
|||
import platform
|
||||
import os
|
||||
import sys
|
||||
import ctypes
|
||||
import logging
|
||||
import subprocess
|
||||
import json
|
||||
from typing import List
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
from buzz.transcriber.transcriber import Segment, Task, FileTranscriptionTask
|
||||
from buzz.transcriber.file_transcriber import app_env
|
||||
from typing import Union, Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from buzz import whisper_audio
|
||||
from buzz.model_loader import LOADED_WHISPER_CPP_BINARY
|
||||
from buzz.transcriber.transcriber import Segment, Task, TranscriptionOptions
|
||||
|
||||
if LOADED_WHISPER_CPP_BINARY:
|
||||
from buzz import whisper_cpp
|
||||
|
||||
|
||||
IS_VULKAN_SUPPORTED = False
|
||||
try:
|
||||
import vulkan
|
||||
IS_COREML_SUPPORTED = False
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
try:
|
||||
from buzz import whisper_cpp_coreml # noqa: F401
|
||||
|
||||
instance = vulkan.vkCreateInstance(vulkan.VkInstanceCreateInfo(), None)
|
||||
vulkan.vkDestroyInstance(instance, None)
|
||||
vulkan_version = vulkan.vkEnumerateInstanceVersion()
|
||||
major = (vulkan_version >> 22) & 0x3FF
|
||||
minor = (vulkan_version >> 12) & 0x3FF
|
||||
|
||||
logging.debug("Vulkan version = %s.%s", major, minor)
|
||||
|
||||
# On macOS, default whisper_cpp is compiled with CoreML (Apple Silicon) or Vulkan (Intel).
|
||||
if platform.system() in ("Linux", "Windows") and ((major > 1) or (major == 1 and minor >= 2)):
|
||||
IS_VULKAN_SUPPORTED = True
|
||||
|
||||
except (ImportError, Exception) as e:
|
||||
logging.debug(f"Vulkan import error: {e}")
|
||||
|
||||
IS_VULKAN_SUPPORTED = False
|
||||
IS_COREML_SUPPORTED = True
|
||||
except ImportError:
|
||||
logging.exception("")
|
||||
|
||||
|
||||
class WhisperCpp:
|
||||
@staticmethod
|
||||
def transcribe(task: FileTranscriptionTask) -> List[Segment]:
|
||||
"""Transcribe audio using whisper-cli subprocess."""
|
||||
cli_executable = "whisper-cli.exe" if sys.platform == "win32" else "whisper-cli"
|
||||
whisper_cli_path = os.path.join(APP_BASE_DIR, "whisper_cpp", cli_executable)
|
||||
def __init__(self, model: str) -> None:
|
||||
|
||||
# If running Mac and Windows installed version
|
||||
if not os.path.exists(whisper_cli_path):
|
||||
whisper_cli_path = os.path.join(APP_BASE_DIR, "buzz", "whisper_cpp", cli_executable)
|
||||
self.is_coreml_supported = IS_COREML_SUPPORTED
|
||||
|
||||
language = (
|
||||
task.transcription_options.language
|
||||
if task.transcription_options.language is not None
|
||||
else "en"
|
||||
)
|
||||
if self.is_coreml_supported:
|
||||
coreml_model = model.replace(".bin", "-encoder.mlmodelc")
|
||||
if not os.path.exists(coreml_model):
|
||||
self.is_coreml_supported = False
|
||||
|
||||
# Check if file format is supported, convert to WAV if not
|
||||
supported_formats = ('.mp3', '.wav', '.flac')
|
||||
file_ext = os.path.splitext(task.file_path)[1].lower()
|
||||
logging.debug(f"WhisperCpp model {model}, (Core ML: {self.is_coreml_supported})")
|
||||
|
||||
temp_file = None
|
||||
file_to_process = task.file_path
|
||||
self.instance = self.get_instance()
|
||||
self.ctx = self.instance.init_from_file(model)
|
||||
self.segments: List[Segment] = []
|
||||
|
||||
if file_ext not in supported_formats:
|
||||
temp_file = task.file_path + ".wav"
|
||||
def append_segment(self, txt: bytes, start: int, end: int):
|
||||
if txt == b'':
|
||||
return True
|
||||
|
||||
logging.info(f"Converting {task.file_path} to WAV format")
|
||||
|
||||
# Convert using ffmpeg
|
||||
ffmpeg_cmd = [
|
||||
"ffmpeg",
|
||||
"-i", task.file_path,
|
||||
"-ar", "16000", # 16kHz sample rate (whisper standard)
|
||||
"-ac", "1", # mono
|
||||
"-y", # overwrite output file
|
||||
temp_file
|
||||
]
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
result = subprocess.run(
|
||||
ffmpeg_cmd,
|
||||
capture_output=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW,
|
||||
check = True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(ffmpeg_cmd, capture_output=True, check=True)
|
||||
|
||||
file_to_process = temp_file
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise Exception(f"Failed to convert audio file: {e.stderr.decode()}")
|
||||
except FileNotFoundError:
|
||||
raise Exception("ffmpeg not found. Please install ffmpeg to process this audio format.")
|
||||
|
||||
# Build the command
|
||||
cmd = [
|
||||
whisper_cli_path,
|
||||
"--model", task.model_path,
|
||||
"--language", language,
|
||||
"--print-progress",
|
||||
"--suppress-nst",
|
||||
# Protections against hallucinated repetition. Seems to be problem on macOS
|
||||
# https://github.com/ggml-org/whisper.cpp/issues/1507
|
||||
"--max-context", "64",
|
||||
"--entropy-thold", "2.8",
|
||||
"--output-json-full",
|
||||
"--threads", str(os.getenv("BUZZ_WHISPERCPP_N_THREADS", (os.cpu_count() or 8) // 2)),
|
||||
"-f", file_to_process,
|
||||
]
|
||||
|
||||
# Add VAD if the model is available
|
||||
vad_model_path = os.path.join(os.path.dirname(whisper_cli_path), "ggml-silero-v6.2.0.bin")
|
||||
if os.path.exists(vad_model_path):
|
||||
cmd.extend(["--vad", "--vad-model", vad_model_path])
|
||||
|
||||
# Add translate flag if needed
|
||||
if task.transcription_options.task == Task.TRANSLATE:
|
||||
cmd.extend(["--translate"])
|
||||
|
||||
# Force CPU if specified
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
if force_cpu != "false" or (not IS_VULKAN_SUPPORTED and platform.system() != "Darwin"):
|
||||
cmd.extend(["--no-gpu"])
|
||||
|
||||
print(f"Running Whisper CLI: {' '.join(cmd)}")
|
||||
|
||||
# Run the whisper-cli process
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# Capture stderr for progress updates
|
||||
stderr_output = []
|
||||
while True:
|
||||
line = process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
stderr_output.append(line.strip())
|
||||
# Progress is written to stderr
|
||||
sys.stderr.write(line)
|
||||
|
||||
process.wait()
|
||||
|
||||
if process.returncode != 0:
|
||||
# Clean up temp file if conversion was done
|
||||
if temp_file and os.path.exists(temp_file):
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove temporary file {temp_file}: {e}")
|
||||
raise Exception(f"whisper-cli failed with return code {process.returncode}")
|
||||
|
||||
# Find and read the generated JSON file
|
||||
# whisper-cli generates: input_file.ext.json (e.g., file.mp3.json)
|
||||
json_output_path = f"{file_to_process}.json"
|
||||
|
||||
# try-catch will guard against multi-byte utf-8 characters
|
||||
# https://github.com/ggerganov/whisper.cpp/issues/1798
|
||||
try:
|
||||
# Read JSON with latin-1 to preserve raw bytes, then handle encoding per field
|
||||
# This is needed because whisper-cli can write invalid UTF-8 sequences for multi-byte characters
|
||||
with open(json_output_path, 'r', encoding='latin-1') as f:
|
||||
result = json.load(f)
|
||||
|
||||
segments = []
|
||||
|
||||
# Handle word-level timings
|
||||
if task.transcription_options.word_level_timings:
|
||||
# Extract word-level timestamps from tokens array
|
||||
# Combine tokens into words using similar logic as whisper_cpp.py
|
||||
transcription = result.get("transcription", [])
|
||||
self.segments.append(
|
||||
Segment(
|
||||
start=start * 10, # centisecond to ms
|
||||
end=end * 10, # centisecond to ms
|
||||
text=txt.decode("utf-8"),
|
||||
)
|
||||
)
|
||||
|
||||
# Languages that don't use spaces between words
|
||||
# For these, each token is treated as a separate word
|
||||
non_space_languages = {"zh", "ja", "th", "lo", "km", "my"}
|
||||
is_non_space_language = language in non_space_languages
|
||||
return True
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
|
||||
for segment_data in transcription:
|
||||
tokens = segment_data.get("tokens", [])
|
||||
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
|
||||
self.segments = []
|
||||
|
||||
if is_non_space_language:
|
||||
# For languages without spaces (Chinese, Japanese, etc.),
|
||||
# each complete UTF-8 character is treated as a separate word.
|
||||
# Some characters may be split across multiple tokens as raw bytes.
|
||||
char_buffer = b""
|
||||
char_start = 0
|
||||
char_end = 0
|
||||
if isinstance(audio, str):
|
||||
audio = whisper_audio.load_audio(audio)
|
||||
|
||||
def flush_complete_chars(buffer: bytes, start: int, end: int):
|
||||
"""Extract and output all complete UTF-8 characters from buffer.
|
||||
Returns any remaining incomplete bytes."""
|
||||
nonlocal segments
|
||||
remaining = buffer
|
||||
pos = 0
|
||||
logging.debug("Loaded audio with length = %s", len(audio))
|
||||
|
||||
while pos < len(remaining):
|
||||
# Try to decode one character at a time
|
||||
for char_len in range(1, min(5, len(remaining) - pos + 1)):
|
||||
try:
|
||||
char = remaining[pos:pos + char_len].decode("utf-8")
|
||||
# Successfully decoded a character
|
||||
if char.strip():
|
||||
segments.append(
|
||||
Segment(
|
||||
start=start,
|
||||
end=end,
|
||||
text=char,
|
||||
translation=""
|
||||
)
|
||||
)
|
||||
pos += char_len
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
if char_len == 4 or pos + char_len >= len(remaining):
|
||||
# Incomplete character at end - return as remaining
|
||||
return remaining[pos:]
|
||||
else:
|
||||
# Couldn't decode, might be incomplete at end
|
||||
return remaining[pos:]
|
||||
whisper_cpp_audio = audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
result = self.instance.full(
|
||||
self.ctx, params, whisper_cpp_audio, len(audio)
|
||||
)
|
||||
if result != 0:
|
||||
raise Exception(f"Error from whisper.cpp: {result}")
|
||||
|
||||
return b""
|
||||
n_segments = self.instance.full_n_segments(self.ctx)
|
||||
|
||||
for token_data in tokens:
|
||||
token_text = token_data.get("text", "")
|
||||
if params.token_timestamps:
|
||||
# Will process word timestamps
|
||||
txt_buffer = b''
|
||||
txt_start = 0
|
||||
txt_end = 0
|
||||
|
||||
# Skip special tokens like [_TT_], [_BEG_]
|
||||
if token_text.startswith("[_"):
|
||||
continue
|
||||
for i in range(n_segments):
|
||||
txt = self.instance.full_get_segment_text(self.ctx, i)
|
||||
start = self.instance.full_get_segment_t0(self.ctx, i)
|
||||
end = self.instance.full_get_segment_t1(self.ctx, i)
|
||||
|
||||
if not token_text:
|
||||
continue
|
||||
if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end):
|
||||
txt_buffer = txt
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
token_start = int(token_data.get("offsets", {}).get("from", 0))
|
||||
token_end = int(token_data.get("offsets", {}).get("to", 0))
|
||||
if txt.startswith(b', '):
|
||||
txt_buffer += b','
|
||||
self.append_segment(txt_buffer, txt_start, txt_end)
|
||||
txt_buffer = txt.lstrip(b',')
|
||||
txt_start = start
|
||||
txt_end = end
|
||||
continue
|
||||
|
||||
# Convert latin-1 string back to original bytes
|
||||
token_bytes = token_text.encode("latin-1")
|
||||
txt_buffer += txt
|
||||
txt_end = end
|
||||
|
||||
if not char_buffer:
|
||||
char_start = token_start
|
||||
# Append the last segment
|
||||
self.append_segment(txt_buffer, txt_start, txt_end)
|
||||
|
||||
char_buffer += token_bytes
|
||||
char_end = token_end
|
||||
else:
|
||||
for i in range(n_segments):
|
||||
txt = self.instance.full_get_segment_text(self.ctx, i)
|
||||
start = self.instance.full_get_segment_t0(self.ctx, i)
|
||||
end = self.instance.full_get_segment_t1(self.ctx, i)
|
||||
|
||||
# Try to flush complete characters
|
||||
char_buffer = flush_complete_chars(char_buffer, char_start, char_end)
|
||||
self.append_segment(txt, start, end)
|
||||
|
||||
# If buffer was fully flushed, reset start time for next char
|
||||
if not char_buffer:
|
||||
char_start = token_end
|
||||
return {
|
||||
"segments": self.segments,
|
||||
"text": "".join([segment.text for segment in self.segments]),
|
||||
}
|
||||
|
||||
# Flush any remaining buffer at end of segment
|
||||
if char_buffer:
|
||||
flush_complete_chars(char_buffer, char_start, char_end)
|
||||
else:
|
||||
# For space-separated languages, accumulate tokens into words
|
||||
word_buffer = b""
|
||||
word_start = 0
|
||||
word_end = 0
|
||||
def get_instance(self):
|
||||
if self.is_coreml_supported:
|
||||
return WhisperCppCoreML()
|
||||
return WhisperCppCpu()
|
||||
|
||||
def append_word(buffer: bytes, start: int, end: int):
|
||||
"""Try to decode and append a word segment, handling multi-byte UTF-8"""
|
||||
if not buffer:
|
||||
return True
|
||||
def get_params(
|
||||
self,
|
||||
transcription_options: TranscriptionOptions,
|
||||
print_realtime=False,
|
||||
print_progress=False,
|
||||
):
|
||||
params = self.instance.full_default_params(whisper_cpp.WHISPER_SAMPLING_GREEDY)
|
||||
params.n_threads = int(os.getenv("BUZZ_WHISPERCPP_N_THREADS", 4))
|
||||
params.print_realtime = print_realtime
|
||||
params.print_progress = print_progress
|
||||
params.language = self.instance.get_string((transcription_options.language or "en"))
|
||||
params.translate = transcription_options.task == Task.TRANSLATE
|
||||
params.max_len = ctypes.c_int(1)
|
||||
params.max_len = 1 if transcription_options.word_level_timings else 0
|
||||
params.token_timestamps = transcription_options.word_level_timings
|
||||
params.initial_prompt = self.instance.get_string(transcription_options.initial_prompt)
|
||||
return params
|
||||
|
||||
# Try to decode as UTF-8
|
||||
# https://github.com/ggerganov/whisper.cpp/issues/1798
|
||||
try:
|
||||
text = buffer.decode("utf-8").strip()
|
||||
if text:
|
||||
segments.append(
|
||||
Segment(
|
||||
start=start,
|
||||
end=end,
|
||||
text=text,
|
||||
translation=""
|
||||
)
|
||||
)
|
||||
return True
|
||||
except UnicodeDecodeError:
|
||||
# Multi-byte character is split, continue accumulating
|
||||
return False
|
||||
def __del__(self):
|
||||
if self.instance:
|
||||
self.instance.free(self.ctx)
|
||||
|
||||
for token_data in tokens:
|
||||
# Token text is read as latin-1, need to convert to bytes to get original data
|
||||
token_text = token_data.get("text", "")
|
||||
|
||||
# Skip special tokens like [_TT_], [_BEG_]
|
||||
if token_text.startswith("[_"):
|
||||
continue
|
||||
class WhisperCppInterface:
|
||||
def full_default_params(self, sampling: int):
|
||||
raise NotImplementedError
|
||||
|
||||
if not token_text:
|
||||
continue
|
||||
def get_string(self, string: str):
|
||||
raise NotImplementedError
|
||||
|
||||
# Skip low probability tokens
|
||||
token_p = token_data.get("p", 1.0)
|
||||
if token_p < 0.01:
|
||||
continue
|
||||
def get_encoder_begin_callback(self, callback):
|
||||
raise NotImplementedError
|
||||
|
||||
token_start = int(token_data.get("offsets", {}).get("from", 0))
|
||||
token_end = int(token_data.get("offsets", {}).get("to", 0))
|
||||
def get_new_segment_callback(self, callback):
|
||||
raise NotImplementedError
|
||||
|
||||
# Convert latin-1 string back to original bytes
|
||||
# (latin-1 preserves byte values as code points)
|
||||
token_bytes = token_text.encode("latin-1")
|
||||
def init_from_file(self, model: str):
|
||||
raise NotImplementedError
|
||||
|
||||
# Check if token starts with space - indicates new word
|
||||
if token_bytes.startswith(b" ") and word_buffer:
|
||||
# Save previous word
|
||||
append_word(word_buffer, word_start, word_end)
|
||||
# Start new word
|
||||
word_buffer = token_bytes
|
||||
word_start = token_start
|
||||
word_end = token_end
|
||||
elif token_bytes.startswith(b", "):
|
||||
# Handle comma - save word with comma, then start new word
|
||||
word_buffer += b","
|
||||
append_word(word_buffer, word_start, word_end)
|
||||
word_buffer = token_bytes.lstrip(b",")
|
||||
word_start = token_start
|
||||
word_end = token_end
|
||||
else:
|
||||
# Accumulate token into current word
|
||||
if not word_buffer:
|
||||
word_start = token_start
|
||||
word_buffer += token_bytes
|
||||
word_end = token_end
|
||||
def full(self, ctx, params, audio, length):
|
||||
raise NotImplementedError
|
||||
|
||||
# Add the last word
|
||||
append_word(word_buffer, word_start, word_end)
|
||||
else:
|
||||
# Use segment-level timestamps
|
||||
transcription = result.get("transcription", [])
|
||||
for segment_data in transcription:
|
||||
# Segment text is also read as latin-1, convert back to UTF-8
|
||||
segment_text_latin1 = segment_data.get("text", "")
|
||||
try:
|
||||
# Convert latin-1 string to bytes, then decode as UTF-8
|
||||
segment_text = segment_text_latin1.encode("latin-1").decode("utf-8").strip()
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# If conversion fails, use the original text
|
||||
segment_text = segment_text_latin1.strip()
|
||||
|
||||
segments.append(
|
||||
Segment(
|
||||
start=int(segment_data.get("offsets", {}).get("from", 0)),
|
||||
end=int(segment_data.get("offsets", {}).get("to", 0)),
|
||||
text=segment_text,
|
||||
translation=""
|
||||
)
|
||||
)
|
||||
|
||||
return segments
|
||||
finally:
|
||||
# Clean up the generated JSON file
|
||||
if os.path.exists(json_output_path):
|
||||
try:
|
||||
os.remove(json_output_path)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove JSON output file {json_output_path}: {e}")
|
||||
def full_n_segments(self, ctx):
|
||||
raise NotImplementedError
|
||||
|
||||
# Clean up temporary audio file if conversion was done
|
||||
if temp_file and os.path.exists(temp_file):
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove temporary file {temp_file}: {e}")
|
||||
def full_get_segment_text(self, ctx, i):
|
||||
raise NotImplementedError
|
||||
|
||||
def full_get_segment_t0(self, ctx, i):
|
||||
raise NotImplementedError
|
||||
|
||||
def full_get_segment_t1(self, ctx, i):
|
||||
raise NotImplementedError
|
||||
|
||||
def free(self, ctx):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WhisperCppCpu(WhisperCppInterface):
|
||||
def full_default_params(self, sampling: int):
|
||||
return whisper_cpp.whisper_full_default_params(sampling)
|
||||
|
||||
def get_string(self, string: str):
|
||||
return whisper_cpp.String(string.encode())
|
||||
|
||||
def get_encoder_begin_callback(self, callback):
|
||||
return whisper_cpp.whisper_encoder_begin_callback(callback)
|
||||
|
||||
def get_new_segment_callback(self, callback):
|
||||
return whisper_cpp.whisper_new_segment_callback(callback)
|
||||
|
||||
def init_from_file(self, model: str):
|
||||
return whisper_cpp.whisper_init_from_file(model.encode())
|
||||
|
||||
def full(self, ctx, params, audio, length):
|
||||
return whisper_cpp.whisper_full(ctx, params, audio, length)
|
||||
|
||||
def full_n_segments(self, ctx):
|
||||
return whisper_cpp.whisper_full_n_segments(ctx)
|
||||
|
||||
def full_get_segment_text(self, ctx, i):
|
||||
return whisper_cpp.whisper_full_get_segment_text(ctx, i)
|
||||
|
||||
def full_get_segment_t0(self, ctx, i):
|
||||
return whisper_cpp.whisper_full_get_segment_t0(ctx, i)
|
||||
|
||||
def full_get_segment_t1(self, ctx, i):
|
||||
return whisper_cpp.whisper_full_get_segment_t1(ctx, i)
|
||||
|
||||
def free(self, ctx):
|
||||
return whisper_cpp.whisper_free(ctx)
|
||||
|
||||
|
||||
class WhisperCppCoreML(WhisperCppInterface):
|
||||
def full_default_params(self, sampling: int):
|
||||
return whisper_cpp_coreml.whisper_full_default_params(sampling)
|
||||
|
||||
def get_string(self, string: str):
|
||||
return whisper_cpp_coreml.String(string.encode())
|
||||
|
||||
def get_encoder_begin_callback(self, callback):
|
||||
return whisper_cpp_coreml.whisper_encoder_begin_callback(callback)
|
||||
|
||||
def get_new_segment_callback(self, callback):
|
||||
return whisper_cpp_coreml.whisper_new_segment_callback(callback)
|
||||
|
||||
def init_from_file(self, model: str):
|
||||
return whisper_cpp_coreml.whisper_init_from_file(model.encode())
|
||||
|
||||
def full(self, ctx, params, audio, length):
|
||||
return whisper_cpp_coreml.whisper_full(ctx, params, audio, length)
|
||||
|
||||
def full_n_segments(self, ctx):
|
||||
return whisper_cpp_coreml.whisper_full_n_segments(ctx)
|
||||
|
||||
def full_get_segment_text(self, ctx, i):
|
||||
return whisper_cpp_coreml.whisper_full_get_segment_text(ctx, i)
|
||||
|
||||
def full_get_segment_t0(self, ctx, i):
|
||||
return whisper_cpp_coreml.whisper_full_get_segment_t0(ctx, i)
|
||||
|
||||
def full_get_segment_t1(self, ctx, i):
|
||||
return whisper_cpp_coreml.whisper_full_get_segment_t1(ctx, i)
|
||||
|
||||
def free(self, ctx):
|
||||
return whisper_cpp_coreml.whisper_free(ctx)
|
||||
|
|
|
|||
91
buzz/transcriber/whisper_cpp_file_transcriber.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
import ctypes
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional, List
|
||||
|
||||
from PyQt6.QtCore import QObject
|
||||
|
||||
from buzz import whisper_audio
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment, Stopped
|
||||
from buzz.transcriber.whisper_cpp import WhisperCpp
|
||||
|
||||
|
||||
class WhisperCppFileTranscriber(FileTranscriber):
|
||||
duration_audio_ms = sys.maxsize # max int
|
||||
state: "WhisperCppFileTranscriber.State"
|
||||
|
||||
class State:
|
||||
running = True
|
||||
|
||||
def __init__(
|
||||
self, task: FileTranscriptionTask, parent: Optional["QObject"] = None
|
||||
) -> None:
|
||||
super().__init__(task, parent)
|
||||
|
||||
self.transcription_options = task.transcription_options
|
||||
self.model_path = task.model_path
|
||||
self.model = WhisperCpp(model=self.model_path)
|
||||
self.state = self.State()
|
||||
|
||||
def transcribe(self) -> List[Segment]:
|
||||
self.state.running = True
|
||||
|
||||
logging.debug(
|
||||
"Starting whisper_cpp file transcription, file path = %s, language = %s, "
|
||||
"task = %s, model_path = %s, word level timings = %s",
|
||||
self.transcription_task.file_path,
|
||||
self.transcription_options.language,
|
||||
self.transcription_options.task,
|
||||
self.model_path,
|
||||
self.transcription_options.word_level_timings,
|
||||
)
|
||||
|
||||
audio = whisper_audio.load_audio(self.transcription_task.file_path)
|
||||
self.duration_audio_ms = len(audio) * 1000 / whisper_audio.SAMPLE_RATE
|
||||
|
||||
whisper_params = self.model.get_params(
|
||||
transcription_options=self.transcription_options
|
||||
)
|
||||
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(
|
||||
id(self.state)
|
||||
)
|
||||
whisper_params.encoder_begin_callback = (
|
||||
self.model.get_instance().get_encoder_begin_callback(self.encoder_begin_callback)
|
||||
)
|
||||
whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state))
|
||||
whisper_params.new_segment_callback = self.model.get_instance().get_new_segment_callback(
|
||||
self.new_segment_callback
|
||||
)
|
||||
|
||||
result = self.model.transcribe(
|
||||
audio=self.transcription_task.file_path, params=whisper_params
|
||||
)
|
||||
|
||||
if not self.state.running:
|
||||
raise Stopped
|
||||
|
||||
self.state.running = False
|
||||
return result["segments"]
|
||||
|
||||
def new_segment_callback(self, ctx, _state, _n_new, user_data):
|
||||
n_segments = self.model.get_instance().full_n_segments(ctx)
|
||||
t1 = self.model.get_instance().full_get_segment_t1(ctx, n_segments - 1)
|
||||
# t1 seems to sometimes be larger than the duration when the
|
||||
# audio ends in silence. Trim to fix the displayed progress.
|
||||
progress = min(t1 * 10, self.duration_audio_ms)
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(
|
||||
user_data, ctypes.py_object
|
||||
).value
|
||||
if state.running:
|
||||
self.progress.emit((progress, self.duration_audio_ms))
|
||||
|
||||
@staticmethod
|
||||
def encoder_begin_callback(_ctx, _state, user_data):
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(
|
||||
user_data, ctypes.py_object
|
||||
).value
|
||||
return state.running == 1
|
||||
|
||||
def stop(self):
|
||||
self.state.running = False
|
||||
|
|
@ -5,13 +5,8 @@ import multiprocessing
|
|||
import re
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Preload CUDA libraries before importing torch - required for subprocess contexts
|
||||
from buzz import cuda_setup # noqa: F401
|
||||
|
||||
import torch
|
||||
import platform
|
||||
import subprocess
|
||||
from platformdirs import user_cache_dir
|
||||
from multiprocessing.connection import Connection
|
||||
from threading import Thread
|
||||
|
|
@ -22,13 +17,11 @@ from PyQt6.QtCore import QObject
|
|||
|
||||
from buzz import whisper_audio
|
||||
from buzz.conn import pipe_stderr
|
||||
from buzz.model_loader import ModelType, WhisperModelSize, map_language_to_mms
|
||||
from buzz.transformers_whisper import TransformersTranscriber
|
||||
from buzz.model_loader import ModelType, WhisperModelSize
|
||||
from buzz.transformers_whisper import TransformersWhisper
|
||||
from buzz.transcriber.file_transcriber import FileTranscriber
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment, Task, DEFAULT_WHISPER_TEMPERATURE
|
||||
from buzz.transcriber.whisper_cpp import WhisperCpp
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
|
||||
|
||||
import av
|
||||
import faster_whisper
|
||||
import whisper
|
||||
import stable_whisper
|
||||
|
|
@ -37,22 +30,6 @@ from stable_whisper import WhisperResult
|
|||
PROGRESS_REGEX = re.compile(r"\d+(\.\d+)?%")
|
||||
|
||||
|
||||
def check_file_has_audio_stream(file_path: str) -> None:
|
||||
"""Check if a media file has at least one audio stream.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file has no audio streams.
|
||||
"""
|
||||
try:
|
||||
with av.open(file_path) as container:
|
||||
if len(container.streams.audio) == 0:
|
||||
raise ValueError("No audio streams found")
|
||||
except av.error.InvalidDataError as e:
|
||||
raise ValueError(f"Invalid media file: {e}")
|
||||
except av.error.FileNotFoundError:
|
||||
raise ValueError("File not found")
|
||||
|
||||
|
||||
class WhisperFileTranscriber(FileTranscriber):
|
||||
"""WhisperFileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file
|
||||
using the default program for opening txt files."""
|
||||
|
|
@ -69,9 +46,6 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
self.segments = []
|
||||
self.started_process = False
|
||||
self.stopped = False
|
||||
self.recv_pipe = None
|
||||
self.send_pipe = None
|
||||
self.error_message = None
|
||||
|
||||
def transcribe(self) -> List[Segment]:
|
||||
time_started = datetime.datetime.now()
|
||||
|
|
@ -82,44 +56,24 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
if torch.cuda.is_available():
|
||||
logging.debug(f"CUDA version detected: {torch.version.cuda}")
|
||||
|
||||
self.recv_pipe, self.send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=self.transcribe_whisper, args=(self.send_pipe, self.transcription_task)
|
||||
target=self.transcribe_whisper, args=(send_pipe, self.transcription_task)
|
||||
)
|
||||
if not self.stopped:
|
||||
self.current_process.start()
|
||||
self.started_process = True
|
||||
|
||||
self.read_line_thread = Thread(target=self.read_line, args=(self.recv_pipe,))
|
||||
self.read_line_thread = Thread(target=self.read_line, args=(recv_pipe,))
|
||||
self.read_line_thread.start()
|
||||
|
||||
# Only join the process if it was actually started
|
||||
if self.started_process:
|
||||
self.current_process.join()
|
||||
self.current_process.join()
|
||||
|
||||
# Close the send pipe after process ends to signal read_line thread to stop
|
||||
# This prevents the read thread from blocking on recv() after the process is gone
|
||||
try:
|
||||
if self.send_pipe and not self.send_pipe.closed:
|
||||
self.send_pipe.close()
|
||||
except OSError:
|
||||
pass
|
||||
if self.current_process.exitcode != 0:
|
||||
send_pipe.close()
|
||||
|
||||
# Close the receive pipe to unblock the read_line thread
|
||||
try:
|
||||
if self.recv_pipe and not self.recv_pipe.closed:
|
||||
self.recv_pipe.close()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Join read_line_thread with timeout to prevent hanging
|
||||
if self.read_line_thread and self.read_line_thread.is_alive():
|
||||
self.read_line_thread.join(timeout=3)
|
||||
if self.read_line_thread.is_alive():
|
||||
logging.warning("Read line thread didn't terminate gracefully in transcribe()")
|
||||
|
||||
self.started_process = False
|
||||
self.read_line_thread.join()
|
||||
|
||||
logging.debug(
|
||||
"whisper process completed with code = %s, time taken = %s,"
|
||||
|
|
@ -130,14 +84,7 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
)
|
||||
|
||||
if self.current_process.exitcode != 0:
|
||||
# Check if the process was terminated (likely due to cancellation)
|
||||
# Exit codes 124-128 are often used for termination signals
|
||||
if self.current_process.exitcode in [124, 125, 126, 127, 128, 130, 137, 143]:
|
||||
# Process was likely terminated, treat as cancellation
|
||||
logging.debug("Whisper process was terminated (exit code: %s), treating as cancellation", self.current_process.exitcode)
|
||||
raise Exception("Transcription was canceled")
|
||||
else:
|
||||
raise Exception(self.error_message or "Unknown error")
|
||||
raise Exception("Unknown error")
|
||||
|
||||
return self.segments
|
||||
|
||||
|
|
@ -145,97 +92,39 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
def transcribe_whisper(
|
||||
cls, stderr_conn: Connection, task: FileTranscriptionTask
|
||||
) -> None:
|
||||
# Patch subprocess on Windows to prevent console window flash
|
||||
# This is needed because multiprocessing spawns a new process without the main process patches
|
||||
if sys.platform == "win32":
|
||||
import subprocess
|
||||
_original_run = subprocess.run
|
||||
_original_popen = subprocess.Popen
|
||||
with pipe_stderr(stderr_conn):
|
||||
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
|
||||
sys.stderr.write("0%\n")
|
||||
segments = cls.transcribe_hugging_face(task)
|
||||
sys.stderr.write("100%\n")
|
||||
elif (
|
||||
task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
|
||||
):
|
||||
segments = cls.transcribe_faster_whisper(task)
|
||||
elif task.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
segments = cls.transcribe_openai_whisper(task)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Invalid model type: {task.transcription_options.model.model_type}"
|
||||
)
|
||||
|
||||
def _patched_run(*args, **kwargs):
|
||||
if 'startupinfo' not in kwargs:
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
kwargs['startupinfo'] = si
|
||||
if 'creationflags' not in kwargs:
|
||||
kwargs['creationflags'] = subprocess.CREATE_NO_WINDOW
|
||||
return _original_run(*args, **kwargs)
|
||||
|
||||
class _PatchedPopen(subprocess.Popen):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'startupinfo' not in kwargs:
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
kwargs['startupinfo'] = si
|
||||
if 'creationflags' not in kwargs:
|
||||
kwargs['creationflags'] = subprocess.CREATE_NO_WINDOW
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
subprocess.run = _patched_run
|
||||
subprocess.Popen = _PatchedPopen
|
||||
|
||||
try:
|
||||
# Check if the file has audio streams before processing
|
||||
check_file_has_audio_stream(task.file_path)
|
||||
|
||||
with pipe_stderr(stderr_conn):
|
||||
if task.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
segments = cls.transcribe_whisper_cpp(task)
|
||||
elif task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
|
||||
sys.stderr.write("0%\n")
|
||||
segments = cls.transcribe_hugging_face(task)
|
||||
sys.stderr.write("100%\n")
|
||||
elif (
|
||||
task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
|
||||
):
|
||||
segments = cls.transcribe_faster_whisper(task)
|
||||
elif task.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
segments = cls.transcribe_openai_whisper(task)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Invalid model type: {task.transcription_options.model.model_type}"
|
||||
)
|
||||
|
||||
segments_json = json.dumps(segments, ensure_ascii=True, default=vars)
|
||||
sys.stderr.write(f"segments = {segments_json}\n")
|
||||
sys.stderr.write(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + "\n")
|
||||
except Exception as e:
|
||||
# Send error message back to the parent process
|
||||
stderr_conn.send(f"error = {str(e)}\n")
|
||||
stderr_conn.send(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + "\n")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def transcribe_whisper_cpp(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
return WhisperCpp.transcribe(task)
|
||||
segments_json = json.dumps(segments, ensure_ascii=True, default=vars)
|
||||
sys.stderr.write(f"segments = {segments_json}\n")
|
||||
sys.stderr.write(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + "\n")
|
||||
|
||||
@classmethod
|
||||
def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
model = TransformersTranscriber(task.model_path)
|
||||
|
||||
# Handle language - MMS uses ISO 639-3 codes, Whisper uses ISO 639-1
|
||||
if model.is_mms_model:
|
||||
language = map_language_to_mms(task.transcription_options.language or "eng")
|
||||
# MMS only supports transcription, ignore translation task
|
||||
effective_task = Task.TRANSCRIBE.value
|
||||
# MMS doesn't support word-level timestamps
|
||||
word_timestamps = False
|
||||
else:
|
||||
language = (
|
||||
task.transcription_options.language
|
||||
if task.transcription_options.language is not None
|
||||
else "en"
|
||||
)
|
||||
effective_task = task.transcription_options.task.value
|
||||
word_timestamps = task.transcription_options.word_level_timings
|
||||
|
||||
model = TransformersWhisper(task.model_path)
|
||||
language = (
|
||||
task.transcription_options.language
|
||||
if task.transcription_options.language is not None
|
||||
else "en"
|
||||
)
|
||||
result = model.transcribe(
|
||||
audio=task.file_path,
|
||||
language=language,
|
||||
task=effective_task,
|
||||
word_timestamps=word_timestamps,
|
||||
task=task.transcription_options.task.value,
|
||||
word_timestamps=task.transcription_options.word_level_timings,
|
||||
)
|
||||
return [
|
||||
Segment(
|
||||
|
|
@ -251,97 +140,68 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
if task.transcription_options.model.whisper_model_size == WhisperModelSize.CUSTOM:
|
||||
model_size_or_path = task.transcription_options.model.hugging_face_model_id
|
||||
elif task.transcription_options.model.whisper_model_size == WhisperModelSize.LARGEV3TURBO:
|
||||
model_size_or_path = "deepdml/faster-whisper-large-v3-turbo-ct2"
|
||||
else:
|
||||
model_size_or_path = task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
|
||||
|
||||
model_root_dir = user_cache_dir("Buzz")
|
||||
model_root_dir = os.path.join(model_root_dir, "models")
|
||||
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
|
||||
device = "auto"
|
||||
if platform.system() == "Windows":
|
||||
logging.debug("CUDA GPUs are currently no supported on Running on Windows, using CPU")
|
||||
device = "cpu"
|
||||
|
||||
if torch.cuda.is_available() and torch.version.cuda < "12":
|
||||
logging.debug("Unsupported CUDA version (<12), using CPU")
|
||||
device = "cpu"
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
logging.debug("CUDA is not available, using CPU")
|
||||
device = "cpu"
|
||||
|
||||
if force_cpu != "false":
|
||||
device = "cpu"
|
||||
|
||||
# Check if user wants reduced GPU memory usage (int8 quantization)
|
||||
reduce_gpu_memory = os.getenv("BUZZ_REDUCE_GPU_MEMORY", "false") != "false"
|
||||
compute_type = "default"
|
||||
if reduce_gpu_memory:
|
||||
compute_type = "int8" if device == "cpu" else "int8_float16"
|
||||
logging.debug(f"Using {compute_type} compute type for reduced memory usage")
|
||||
|
||||
model = faster_whisper.WhisperModel(
|
||||
model_size_or_path=model_size_or_path,
|
||||
download_root=model_root_dir,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
cpu_threads=(os.cpu_count() or 8)//2,
|
||||
)
|
||||
|
||||
batched_model = faster_whisper.BatchedInferencePipeline(model=model)
|
||||
whisper_segments, info = batched_model.transcribe(
|
||||
whisper_segments, info = model.transcribe(
|
||||
audio=task.file_path,
|
||||
language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
# Prevent crash on Windows https://github.com/SYSTRAN/faster-whisper/issues/71#issuecomment-1526263764
|
||||
temperature = 0 if platform.system() == "Windows" else DEFAULT_WHISPER_TEMPERATURE,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
word_timestamps=task.transcription_options.word_level_timings,
|
||||
no_speech_threshold=0.4,
|
||||
log_progress=True,
|
||||
)
|
||||
segments = []
|
||||
for segment in whisper_segments:
|
||||
# Segment will contain words if word-level timings is True
|
||||
if segment.words:
|
||||
for word in segment.words:
|
||||
with tqdm.tqdm(total=round(info.duration, 2), unit=" seconds") as pbar:
|
||||
for segment in list(whisper_segments):
|
||||
# Segment will contain words if word-level timings is True
|
||||
if segment.words:
|
||||
for word in segment.words:
|
||||
segments.append(
|
||||
Segment(
|
||||
start=int(word.start * 1000),
|
||||
end=int(word.end * 1000),
|
||||
text=word.word,
|
||||
translation=""
|
||||
)
|
||||
)
|
||||
else:
|
||||
segments.append(
|
||||
Segment(
|
||||
start=int(word.start * 1000),
|
||||
end=int(word.end * 1000),
|
||||
text=word.word,
|
||||
start=int(segment.start * 1000),
|
||||
end=int(segment.end * 1000),
|
||||
text=segment.text,
|
||||
translation=""
|
||||
)
|
||||
)
|
||||
else:
|
||||
segments.append(
|
||||
Segment(
|
||||
start=int(segment.start * 1000),
|
||||
end=int(segment.end * 1000),
|
||||
text=segment.text,
|
||||
translation=""
|
||||
)
|
||||
)
|
||||
|
||||
pbar.update(segment.end - segment.start)
|
||||
return segments
|
||||
|
||||
@classmethod
|
||||
def transcribe_openai_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
use_cuda = torch.cuda.is_available() and force_cpu == "false"
|
||||
|
||||
device = "cuda" if use_cuda else "cpu"
|
||||
|
||||
# Monkeypatch torch.load to use weights_only=False for PyTorch 2.6+
|
||||
# This is required for loading Whisper models with the newer PyTorch versions
|
||||
original_torch_load = torch.load
|
||||
def patched_torch_load(*args, **kwargs):
|
||||
kwargs.setdefault('weights_only', False)
|
||||
return original_torch_load(*args, **kwargs)
|
||||
|
||||
torch.load = patched_torch_load
|
||||
try:
|
||||
model = whisper.load_model(task.model_path, device=device)
|
||||
finally:
|
||||
torch.load = original_torch_load
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = whisper.load_model(task.model_path, device=device)
|
||||
|
||||
if task.transcription_options.word_level_timings:
|
||||
stable_whisper.modify_model(model)
|
||||
|
|
@ -349,10 +209,8 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
audio=whisper_audio.load_audio(task.file_path),
|
||||
language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=DEFAULT_WHISPER_TEMPERATURE,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
no_speech_threshold=0.4,
|
||||
fp16=False,
|
||||
)
|
||||
return [
|
||||
Segment(
|
||||
|
|
@ -372,7 +230,6 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
verbose=False,
|
||||
fp16=False,
|
||||
)
|
||||
segments = result.get("segments")
|
||||
return [
|
||||
|
|
@ -387,46 +244,14 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
|
||||
def stop(self):
|
||||
self.stopped = True
|
||||
|
||||
if self.started_process:
|
||||
self.current_process.terminate()
|
||||
|
||||
if self.read_line_thread and self.read_line_thread.is_alive():
|
||||
self.read_line_thread.join(timeout=5)
|
||||
if self.read_line_thread.is_alive():
|
||||
logging.warning("Read line thread still alive after 5s")
|
||||
|
||||
self.current_process.join(timeout=10)
|
||||
if self.current_process.is_alive():
|
||||
logging.warning("Process didn't terminate gracefully, force killing")
|
||||
self.current_process.kill()
|
||||
self.current_process.join(timeout=5)
|
||||
|
||||
try:
|
||||
if hasattr(self, 'send_pipe') and self.send_pipe:
|
||||
self.send_pipe.close()
|
||||
except Exception as e:
|
||||
logging.debug(f"Error closing send_pipe: {e}")
|
||||
|
||||
try:
|
||||
if hasattr(self, 'recv_pipe') and self.recv_pipe:
|
||||
self.recv_pipe.close()
|
||||
except Exception as e:
|
||||
logging.debug(f"Error closing recv_pipe: {e}")
|
||||
|
||||
def read_line(self, pipe: Connection):
|
||||
while True:
|
||||
try:
|
||||
line = pipe.recv().strip()
|
||||
|
||||
# Uncomment to debug
|
||||
# print(f"*** DEBUG ***: {line}")
|
||||
|
||||
except (EOFError, BrokenPipeError, ConnectionResetError, OSError):
|
||||
# Connection closed, broken, or process crashed (Windows RPC errors raise OSError)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.debug(f"Error reading from pipe: {e}")
|
||||
except EOFError: # Connection closed
|
||||
break
|
||||
|
||||
if line == self.READ_LINE_THREAD_STOP_TOKEN:
|
||||
|
|
@ -444,8 +269,6 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
for segment in segments_dict
|
||||
]
|
||||
self.segments = segments
|
||||
elif line.startswith("error = "):
|
||||
self.error_message = line[8:]
|
||||
else:
|
||||
try:
|
||||
match = PROGRESS_REGEX.search(line)
|
||||
|
|
|
|||
|
|
@ -1,32 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
import logging
|
||||
import platform
|
||||
import numpy as np
|
||||
|
||||
# Preload CUDA libraries before importing torch
|
||||
from buzz import cuda_setup # noqa: F401
|
||||
|
||||
import torch
|
||||
import requests
|
||||
from typing import Union
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, BitsAndBytesConfig
|
||||
from typing import Optional, Union
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline
|
||||
from transformers.pipelines.audio_utils import ffmpeg_read
|
||||
from transformers.pipelines.automatic_speech_recognition import is_torchaudio_available
|
||||
|
||||
from buzz.model_loader import is_mms_model, map_language_to_mms
|
||||
|
||||
|
||||
def is_intel_mac() -> bool:
|
||||
"""Check if running on Intel Mac (x86_64)."""
|
||||
return sys.platform == 'darwin' and platform.machine() == 'x86_64'
|
||||
|
||||
|
||||
def is_peft_model(model_id: str) -> bool:
|
||||
"""Check if model is a PEFT model based on model ID containing '-peft'."""
|
||||
return "-peft" in model_id.lower()
|
||||
|
||||
|
||||
class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no cover
|
||||
# Copy of transformers `AutomaticSpeechRecognitionPipeline.chunk_iter` method with custom progress output
|
||||
|
|
@ -35,8 +17,7 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
inputs_len = inputs.shape[0]
|
||||
step = chunk_len - stride_left - stride_right
|
||||
for chunk_start_idx in range(0, inputs_len, step):
|
||||
|
||||
# Buzz will print progress to stderr
|
||||
# Print progress to stderr
|
||||
progress = int((chunk_start_idx / inputs_len) * 100)
|
||||
sys.stderr.write(f"{progress}%\n")
|
||||
|
||||
|
|
@ -46,7 +27,8 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
if dtype is not None:
|
||||
processed = processed.to(dtype=dtype)
|
||||
_stride_left = 0 if chunk_start_idx == 0 else stride_left
|
||||
is_last = chunk_end_idx >= inputs_len
|
||||
# all right strides must be full, otherwise it is the last item
|
||||
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
|
||||
_stride_right = 0 if is_last else stride_right
|
||||
|
||||
chunk_len = chunk.shape[0]
|
||||
|
|
@ -116,7 +98,7 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
# of the original length in the stride so we can cut properly.
|
||||
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
|
||||
if not isinstance(inputs, np.ndarray):
|
||||
raise TypeError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
|
||||
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
|
||||
if len(inputs.shape) != 1:
|
||||
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
||||
|
||||
|
|
@ -127,7 +109,7 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
if isinstance(stride_length_s, (int, float)):
|
||||
stride_length_s = [stride_length_s, stride_length_s]
|
||||
|
||||
# XXX: Carefully, this variable will not exist in `seq2seq` setting.
|
||||
# XXX: Carefuly, this variable will not exist in `seq2seq` setting.
|
||||
# Currently chunking is not possible at this level for `seq2seq` so
|
||||
# it's ok.
|
||||
align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
|
||||
|
|
@ -138,11 +120,11 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
if chunk_len < stride_left + stride_right:
|
||||
raise ValueError("Chunk length must be superior to stride length")
|
||||
|
||||
# Buzz use our custom chunk_iter with progress
|
||||
# Will use our custom chunk_iter with progress
|
||||
for item in self.chunk_iter(
|
||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
||||
):
|
||||
yield {**item, **extra}
|
||||
yield item
|
||||
else:
|
||||
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
|
||||
processed = self.feature_extractor(
|
||||
|
|
@ -151,25 +133,12 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
truncation=False,
|
||||
padding="longest",
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
else:
|
||||
if self.type == "seq2seq_whisper" and stride is None:
|
||||
processed = self.feature_extractor(
|
||||
inputs,
|
||||
sampling_rate=self.feature_extractor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
return_token_timestamps=True,
|
||||
return_attention_mask=True,
|
||||
)
|
||||
extra["num_frames"] = processed.pop("num_frames")
|
||||
else:
|
||||
processed = self.feature_extractor(
|
||||
inputs,
|
||||
sampling_rate=self.feature_extractor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
processed = self.feature_extractor(
|
||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
|
||||
if self.torch_dtype is not None:
|
||||
processed = processed.to(dtype=self.torch_dtype)
|
||||
if stride is not None:
|
||||
|
|
@ -180,23 +149,11 @@ class PipelineWithProgress(AutomaticSpeechRecognitionPipeline): # pragma: no co
|
|||
yield {"is_last": True, **processed, **extra}
|
||||
|
||||
|
||||
class TransformersTranscriber:
|
||||
"""Unified transcriber for HuggingFace models (Whisper and MMS)."""
|
||||
|
||||
def __init__(self, model_id: str):
|
||||
class TransformersWhisper:
|
||||
def __init__(
|
||||
self, model_id: str
|
||||
):
|
||||
self.model_id = model_id
|
||||
self._is_mms = is_mms_model(model_id)
|
||||
self._is_peft = is_peft_model(model_id)
|
||||
|
||||
@property
|
||||
def is_mms_model(self) -> bool:
|
||||
"""Returns True if this is an MMS model."""
|
||||
return self._is_mms
|
||||
|
||||
@property
|
||||
def is_peft_model(self) -> bool:
|
||||
"""Returns True if this is a PEFT model."""
|
||||
return self._is_peft
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
|
|
@ -205,316 +162,50 @@ class TransformersTranscriber:
|
|||
task: str,
|
||||
word_timestamps: bool = False,
|
||||
):
|
||||
"""Transcribe audio using either Whisper or MMS model."""
|
||||
if self._is_mms:
|
||||
return self._transcribe_mms(audio, language)
|
||||
else:
|
||||
return self._transcribe_whisper(audio, language, task, word_timestamps)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
def _transcribe_whisper(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
language: str,
|
||||
task: str,
|
||||
word_timestamps: bool = False,
|
||||
):
|
||||
"""Transcribe using Whisper model."""
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
use_cuda = torch.cuda.is_available() and force_cpu == "false"
|
||||
device = "cuda" if use_cuda else "cpu"
|
||||
torch_dtype = torch.float16 if use_cuda else torch.float32
|
||||
use_safetensors = True
|
||||
if os.path.exists(self.model_id):
|
||||
safetensors_files = [f for f in os.listdir(self.model_id) if f.endswith(".safetensors")]
|
||||
use_safetensors = len(safetensors_files) > 0
|
||||
|
||||
# Check if this is a PEFT model
|
||||
if is_peft_model(self.model_id):
|
||||
model, processor, use_8bit = self._load_peft_model(device, torch_dtype)
|
||||
else:
|
||||
use_safetensors = True
|
||||
if os.path.isdir(self.model_id):
|
||||
safetensors_files = [f for f in os.listdir(self.model_id) if f.endswith(".safetensors")]
|
||||
use_safetensors = len(safetensors_files) > 0
|
||||
|
||||
# Check if user wants reduced GPU memory usage (8-bit quantization)
|
||||
# Skip on Intel Macs as bitsandbytes is not available there
|
||||
reduce_gpu_memory = os.getenv("BUZZ_REDUCE_GPU_MEMORY", "false") != "false"
|
||||
use_8bit = False
|
||||
if device == "cuda" and reduce_gpu_memory and not is_intel_mac():
|
||||
try:
|
||||
import bitsandbytes # noqa: F401
|
||||
use_8bit = True
|
||||
print("Using 8-bit quantization for reduced GPU memory usage")
|
||||
except ImportError:
|
||||
print("bitsandbytes not available, using standard precision")
|
||||
|
||||
if use_8bit:
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.model_id,
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto",
|
||||
use_safetensors=use_safetensors
|
||||
)
|
||||
else:
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=use_safetensors
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
model.generation_config.language = language
|
||||
|
||||
processor = AutoProcessor.from_pretrained(self.model_id)
|
||||
|
||||
pipeline_kwargs = {
|
||||
"task": "automatic-speech-recognition",
|
||||
"pipeline_class": PipelineWithProgress,
|
||||
"generate_kwargs": {
|
||||
"language": language,
|
||||
"task": task,
|
||||
"no_repeat_ngram_size": 3,
|
||||
"repetition_penalty": 1.2,
|
||||
},
|
||||
"model": model,
|
||||
"tokenizer": processor.tokenizer,
|
||||
"feature_extractor": processor.feature_extractor,
|
||||
# pipeline has built in chunking, works faster, but we loose progress output
|
||||
# needed for word level timestamps, otherwise there is huge RAM usage on longer audios
|
||||
"chunk_length_s": 30 if word_timestamps else None,
|
||||
"torch_dtype": torch_dtype,
|
||||
"ignore_warning": True, # Ignore warning about chunk_length_s being experimental for seq2seq models
|
||||
}
|
||||
if not use_8bit:
|
||||
pipeline_kwargs["device"] = device
|
||||
pipe = pipeline(**pipeline_kwargs)
|
||||
|
||||
transcript = pipe(
|
||||
audio,
|
||||
return_timestamps="word" if word_timestamps else True,
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
model.generation_config.language = language
|
||||
model.to(device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(self.model_id)
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
pipeline_class=PipelineWithProgress,
|
||||
generate_kwargs={"language": language, "task": task},
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
chunk_length_s=30,
|
||||
torch_dtype=torch_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transcript = pipe(audio, return_timestamps="word" if word_timestamps else True)
|
||||
|
||||
segments = []
|
||||
for chunk in transcript['chunks']:
|
||||
start, end = chunk['timestamp']
|
||||
text = chunk['text']
|
||||
|
||||
# Last segment may not have an end timestamp
|
||||
if start is None:
|
||||
start = 0
|
||||
if end is None:
|
||||
end = start + 0.1
|
||||
|
||||
if end > start and text.strip() != "":
|
||||
segments.append({
|
||||
"start": 0 if start is None else start,
|
||||
"end": 0 if end is None else end,
|
||||
"text": text.strip(),
|
||||
"translation": ""
|
||||
})
|
||||
segments.append({
|
||||
"start": 0 if start is None else start,
|
||||
"end": 0 if end is None else end,
|
||||
"text": text,
|
||||
"translation": ""
|
||||
})
|
||||
|
||||
return {
|
||||
"text": transcript['text'],
|
||||
"segments": segments,
|
||||
}
|
||||
|
||||
def _load_peft_model(self, device: str, torch_dtype):
|
||||
"""Load a PEFT (Parameter-Efficient Fine-Tuning) model.
|
||||
|
||||
PEFT models require loading the base model first, then applying the adapter.
|
||||
The base model path is extracted from the PEFT config.
|
||||
|
||||
Returns:
|
||||
Tuple of (model, processor, use_8bit)
|
||||
"""
|
||||
from peft import PeftModel, PeftConfig
|
||||
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer
|
||||
|
||||
print(f"Loading PEFT model: {self.model_id}")
|
||||
|
||||
# Get the PEFT model ID (handle both local paths and repo IDs)
|
||||
peft_model_id = self._get_peft_repo_id()
|
||||
|
||||
# Load PEFT config to get base model path
|
||||
peft_config = PeftConfig.from_pretrained(peft_model_id)
|
||||
base_model_path = peft_config.base_model_name_or_path
|
||||
print(f"PEFT base model: {base_model_path}")
|
||||
|
||||
# Load the base Whisper model
|
||||
# Use 8-bit quantization on CUDA if user enabled "Reduce GPU RAM" and bitsandbytes is available
|
||||
# Skip on Intel Macs as bitsandbytes is not available there
|
||||
reduce_gpu_memory = os.getenv("BUZZ_REDUCE_GPU_MEMORY", "false") != "false"
|
||||
use_8bit = False
|
||||
if device == "cuda" and reduce_gpu_memory and not is_intel_mac():
|
||||
try:
|
||||
import bitsandbytes # noqa: F401
|
||||
use_8bit = True
|
||||
print("Using 8-bit quantization for reduced GPU memory usage")
|
||||
except ImportError:
|
||||
print("bitsandbytes not available, using standard precision for PEFT model")
|
||||
|
||||
if use_8bit:
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model = WhisperForConditionalGeneration.from_pretrained(
|
||||
base_model_path,
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto"
|
||||
)
|
||||
else:
|
||||
model = WhisperForConditionalGeneration.from_pretrained(
|
||||
base_model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
# Apply the PEFT adapter
|
||||
model = PeftModel.from_pretrained(model, peft_model_id)
|
||||
model.config.use_cache = True
|
||||
|
||||
# Load feature extractor and tokenizer from base model
|
||||
feature_extractor = WhisperFeatureExtractor.from_pretrained(base_model_path)
|
||||
tokenizer = WhisperTokenizer.from_pretrained(base_model_path, task="transcribe")
|
||||
|
||||
# Create a simple processor-like object that the pipeline expects
|
||||
class PeftProcessor:
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
self.feature_extractor = feature_extractor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
processor = PeftProcessor(feature_extractor, tokenizer)
|
||||
|
||||
return model, processor, use_8bit
|
||||
|
||||
def _get_peft_repo_id(self) -> str:
|
||||
"""Extract HuggingFace repo ID from local cache path for PEFT models."""
|
||||
model_id = self.model_id
|
||||
|
||||
# If it's already a repo ID (contains / but not a file path), return as-is
|
||||
if "/" in model_id and not os.path.exists(model_id):
|
||||
return model_id
|
||||
|
||||
# Extract repo ID from cache path
|
||||
if "models--" in model_id:
|
||||
parts = model_id.split("models--")
|
||||
if len(parts) > 1:
|
||||
repo_part = parts[1].split(os.sep + "snapshots")[0]
|
||||
repo_id = repo_part.replace("--", "/", 1)
|
||||
return repo_id
|
||||
|
||||
# Fallback: return as-is
|
||||
return model_id
|
||||
|
||||
def _get_mms_repo_id(self) -> str:
|
||||
"""Extract HuggingFace repo ID from local cache path or return as-is if already a repo ID."""
|
||||
model_id = self.model_id
|
||||
|
||||
# If it's already a repo ID (contains / but not a file path), return as-is
|
||||
if "/" in model_id and not os.path.exists(model_id):
|
||||
return model_id
|
||||
|
||||
# Extract repo ID from cache path like:
|
||||
# Linux: /home/user/.cache/Buzz/models/models--facebook--mms-1b-all/snapshots/xxx
|
||||
# Windows: C:\Users\user\.cache\Buzz\models\models--facebook--mms-1b-all\snapshots\xxx
|
||||
if "models--" in model_id:
|
||||
# Extract the part after "models--" and before "/snapshots" or "\snapshots"
|
||||
parts = model_id.split("models--")
|
||||
if len(parts) > 1:
|
||||
# Split on os.sep to handle both Windows and Unix paths
|
||||
repo_part = parts[1].split(os.sep + "snapshots")[0]
|
||||
# Convert facebook--mms-1b-all to facebook/mms-1b-all
|
||||
repo_id = repo_part.replace("--", "/", 1)
|
||||
return repo_id
|
||||
|
||||
# Fallback: return as-is
|
||||
return model_id
|
||||
|
||||
def _transcribe_mms(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
language: str,
|
||||
):
|
||||
"""Transcribe using MMS (Massively Multilingual Speech) model."""
|
||||
from transformers import Wav2Vec2ForCTC, AutoProcessor as MMSAutoProcessor
|
||||
from transformers.pipelines.audio_utils import ffmpeg_read as mms_ffmpeg_read
|
||||
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
use_cuda = torch.cuda.is_available() and force_cpu == "false"
|
||||
device = "cuda" if use_cuda else "cpu"
|
||||
|
||||
# Map language code to ISO 639-3 for MMS
|
||||
mms_language = map_language_to_mms(language)
|
||||
print(f"MMS transcription with language: {mms_language} (original: {language})")
|
||||
|
||||
sys.stderr.write("0%\n")
|
||||
|
||||
# Use repo ID for MMS to allow adapter downloads
|
||||
# Local paths don't work for adapter downloads
|
||||
repo_id = self._get_mms_repo_id()
|
||||
print(f"MMS using repo ID: {repo_id} (from model_id: {self.model_id})")
|
||||
|
||||
# Load processor and model with target language
|
||||
# This will download the language adapter if not cached
|
||||
processor = MMSAutoProcessor.from_pretrained(
|
||||
repo_id,
|
||||
target_lang=mms_language
|
||||
)
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained(
|
||||
repo_id,
|
||||
target_lang=mms_language,
|
||||
ignore_mismatched_sizes=True
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
sys.stderr.write("25%\n")
|
||||
|
||||
# Load and process audio
|
||||
if isinstance(audio, str):
|
||||
with open(audio, "rb") as f:
|
||||
audio_data = f.read()
|
||||
audio_array = mms_ffmpeg_read(audio_data, processor.feature_extractor.sampling_rate)
|
||||
else:
|
||||
audio_array = audio
|
||||
|
||||
# Ensure audio is the right sample rate
|
||||
sampling_rate = processor.feature_extractor.sampling_rate
|
||||
|
||||
sys.stderr.write("50%\n")
|
||||
|
||||
# Process audio in chunks for progress reporting
|
||||
inputs = processor(
|
||||
audio_array,
|
||||
sampling_rate=sampling_rate,
|
||||
return_tensors="pt",
|
||||
padding=True
|
||||
)
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
sys.stderr.write("75%\n")
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs).logits
|
||||
|
||||
# Decode
|
||||
ids = torch.argmax(outputs, dim=-1)[0]
|
||||
transcription = processor.decode(ids)
|
||||
|
||||
sys.stderr.write("100%\n")
|
||||
|
||||
# Calculate approximate duration for segment
|
||||
duration = len(audio_array) / sampling_rate if isinstance(audio_array, np.ndarray) else 0
|
||||
|
||||
# Return in same format as Whisper for consistency
|
||||
# MMS doesn't provide word-level timestamps, so we return a single segment
|
||||
return {
|
||||
"text": transcription,
|
||||
"segments": [{
|
||||
"start": 0,
|
||||
"end": duration,
|
||||
"text": transcription.strip(),
|
||||
"translation": ""
|
||||
}] if transcription.strip() else []
|
||||
}
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
TransformersWhisper = TransformersTranscriber
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,21 @@
|
|||
import os
|
||||
import re
|
||||
import logging
|
||||
import queue
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from openai import OpenAI, max_retries
|
||||
from typing import Optional
|
||||
from openai import OpenAI
|
||||
from PyQt6.QtCore import QObject, pyqtSignal
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.store.keyring_store import get_password, Key
|
||||
from buzz.transcriber.transcriber import TranscriptionOptions
|
||||
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
|
||||
|
||||
|
||||
BATCH_SIZE = 10
|
||||
|
||||
|
||||
class Translator(QObject):
|
||||
translation = pyqtSignal(str, int)
|
||||
finished = pyqtSignal()
|
||||
is_running = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -52,137 +48,38 @@ class Translator(QObject):
|
|||
)
|
||||
self.openai_client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None,
|
||||
max_retries=0
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None
|
||||
)
|
||||
|
||||
def _translate_single(self, transcript: str, transcript_id: int) -> Tuple[str, int]:
|
||||
"""Translate a single transcript via the API. Returns (translation, transcript_id)."""
|
||||
try:
|
||||
def start(self):
|
||||
logging.debug("Starting translation queue")
|
||||
|
||||
self.is_running = True
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
transcript, transcript_id = self.queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
completion = self.openai_client.chat.completions.create(
|
||||
model=self.transcription_options.llm_model,
|
||||
messages=[
|
||||
{"role": "system", "content": self.transcription_options.llm_prompt},
|
||||
{"role": "user", "content": transcript}
|
||||
],
|
||||
timeout=60.0,
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
completion = None
|
||||
logging.error(f"Translation error! Server response: {e}")
|
||||
|
||||
if completion and completion.choices and completion.choices[0].message:
|
||||
logging.debug(f"Received translation response: {completion}")
|
||||
return completion.choices[0].message.content, transcript_id
|
||||
else:
|
||||
logging.error(f"Translation error! Server response: {completion}")
|
||||
# Translation error
|
||||
return "", transcript_id
|
||||
|
||||
def _translate_batch(self, items: List[Tuple[str, int]]) -> List[Tuple[str, int]]:
|
||||
"""Translate multiple transcripts in a single API call.
|
||||
Returns list of (translation, transcript_id) in the same order as input."""
|
||||
numbered_parts = []
|
||||
for i, (transcript, _) in enumerate(items, 1):
|
||||
numbered_parts.append(f"[{i}] {transcript}")
|
||||
combined = "\n".join(numbered_parts)
|
||||
|
||||
batch_prompt = (
|
||||
f"{self.transcription_options.llm_prompt}\n\n"
|
||||
f"You will receive {len(items)} numbered texts. "
|
||||
f"Process each one separately according to the instruction above "
|
||||
f"and return them in the exact same numbered format, e.g.:\n"
|
||||
f"[1] processed text\n[2] processed text"
|
||||
)
|
||||
|
||||
try:
|
||||
completion = self.openai_client.chat.completions.create(
|
||||
model=self.transcription_options.llm_model,
|
||||
messages=[
|
||||
{"role": "system", "content": batch_prompt},
|
||||
{"role": "user", "content": combined}
|
||||
],
|
||||
timeout=60.0,
|
||||
)
|
||||
except Exception as e:
|
||||
completion = None
|
||||
logging.error(f"Batch translation error! Server response: {e}")
|
||||
|
||||
if not (completion and completion.choices and completion.choices[0].message):
|
||||
logging.error(f"Batch translation error! Server response: {completion}")
|
||||
# Translation error
|
||||
return [("", tid) for _, tid in items]
|
||||
|
||||
response_text = completion.choices[0].message.content
|
||||
logging.debug(f"Received batch translation response: {response_text}")
|
||||
|
||||
translations = self._parse_batch_response(response_text, len(items))
|
||||
|
||||
results = []
|
||||
for i, (_, transcript_id) in enumerate(items):
|
||||
if i < len(translations):
|
||||
results.append((translations[i], transcript_id))
|
||||
if completion.choices and completion.choices[0].message:
|
||||
next_translation = completion.choices[0].message.content
|
||||
else:
|
||||
# Translation error
|
||||
results.append(("", transcript_id))
|
||||
return results
|
||||
logging.error(f"Translation error! Server response: {completion}")
|
||||
next_translation = "Translation error, see logs!"
|
||||
|
||||
@staticmethod
|
||||
def _parse_batch_response(response: str, expected_count: int) -> List[str]:
|
||||
"""Parse a numbered batch response like '[1] text\\n[2] text' into a list of strings."""
|
||||
# Split on [N] markers — re.split with a group returns: [before, group1, after1, group2, after2, ...]
|
||||
parts = re.split(r'\[(\d+)\]\s*', response)
|
||||
self.translation.emit(next_translation, transcript_id)
|
||||
|
||||
translations = {}
|
||||
for i in range(1, len(parts) - 1, 2):
|
||||
num = int(parts[i])
|
||||
text = parts[i + 1].strip()
|
||||
translations[num] = text
|
||||
|
||||
return [
|
||||
translations.get(i, "")
|
||||
for i in range(1, expected_count + 1)
|
||||
]
|
||||
|
||||
def start(self):
|
||||
logging.debug("Starting translation queue")
|
||||
|
||||
while True:
|
||||
item = self.queue.get() # Block until item available
|
||||
|
||||
# Check for sentinel value (None means stop)
|
||||
if item is None:
|
||||
logging.debug("Translation queue received stop signal")
|
||||
break
|
||||
|
||||
# Collect a batch: start with the first item, then drain more
|
||||
batch = [item]
|
||||
stop_after_batch = False
|
||||
while len(batch) < BATCH_SIZE:
|
||||
try:
|
||||
next_item = self.queue.get_nowait()
|
||||
if next_item is None:
|
||||
stop_after_batch = True
|
||||
break
|
||||
batch.append(next_item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
if len(batch) == 1:
|
||||
transcript, transcript_id = batch[0]
|
||||
translation, tid = self._translate_single(transcript, transcript_id)
|
||||
self.translation.emit(translation, tid)
|
||||
else:
|
||||
logging.debug(f"Translating batch of {len(batch)} in single request")
|
||||
results = self._translate_batch(batch)
|
||||
for translation, tid in results:
|
||||
self.translation.emit(translation, tid)
|
||||
|
||||
if stop_after_batch:
|
||||
logging.debug("Translation queue received stop signal")
|
||||
break
|
||||
|
||||
logging.debug("Translation queue stopped")
|
||||
self.finished.emit()
|
||||
|
||||
def on_transcription_options_changed(
|
||||
|
|
@ -194,5 +91,4 @@ class Translator(QObject):
|
|||
self.queue.put((transcript, transcript_id))
|
||||
|
||||
def stop(self):
|
||||
# Send sentinel value to unblock and stop the worker thread
|
||||
self.queue.put(None)
|
||||
self.is_running = False
|
||||
|
|
|
|||
|
|
@ -1,163 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import platform
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from PyQt6.QtCore import QObject, pyqtSignal, QUrl
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.settings.settings import Settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateInfo:
|
||||
version: str
|
||||
release_notes: str
|
||||
download_urls: list
|
||||
|
||||
class UpdateChecker(QObject):
|
||||
update_available = pyqtSignal(object)
|
||||
|
||||
VERSION_JSON_URL = "https://github.com/chidiwilliams/buzz/releases/latest/download/version_info.json"
|
||||
|
||||
CHECK_INTERVAL_DAYS = 7
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
network_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QObject] = None
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.settings = settings
|
||||
|
||||
if network_manager is None:
|
||||
network_manager = QNetworkAccessManager(self)
|
||||
self.network_manager = network_manager
|
||||
self.network_manager.finished.connect(self._on_reply_finished)
|
||||
|
||||
def should_check_for_updates(self) -> bool:
|
||||
"""Check if we are on Windows/macOS and if 7 days passed"""
|
||||
system = platform.system()
|
||||
if system not in ("Windows", "Darwin"):
|
||||
logging.debug("Skipping update check on linux")
|
||||
return False
|
||||
|
||||
last_check = self.settings.value(
|
||||
Settings.Key.LAST_UPDATE_CHECK,
|
||||
"",
|
||||
)
|
||||
|
||||
if last_check:
|
||||
try:
|
||||
last_check_date = datetime.fromisoformat(last_check)
|
||||
days_since_check = (datetime.now() - last_check_date).days
|
||||
if days_since_check < self.CHECK_INTERVAL_DAYS:
|
||||
logging.debug(
|
||||
f"Skipping update check, last checked {days_since_check} days ago"
|
||||
)
|
||||
return False
|
||||
except ValueError:
|
||||
#Invalid date format
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def check_for_updates(self) -> None:
|
||||
"""Start the network request"""
|
||||
if not self.should_check_for_updates():
|
||||
return
|
||||
|
||||
logging.info("Checking for updates...")
|
||||
|
||||
url = QUrl(self.VERSION_JSON_URL)
|
||||
request = QNetworkRequest(url)
|
||||
self.network_manager.get(request)
|
||||
|
||||
def _on_reply_finished(self, reply: QNetworkReply) -> None:
|
||||
"""Handles the network reply for version.json fetch"""
|
||||
self.settings.set_value(
|
||||
Settings.Key.LAST_UPDATE_CHECK,
|
||||
datetime.now().isoformat()
|
||||
)
|
||||
|
||||
if reply.error() != QNetworkReply.NetworkError.NoError:
|
||||
error_msg = f"Failed to check for updates: {reply.errorString()}"
|
||||
logging.error(error_msg)
|
||||
reply.deleteLater()
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(reply.readAll().data().decode("utf-8"))
|
||||
reply.deleteLater()
|
||||
|
||||
remote_version = data.get("version", "")
|
||||
release_notes = data.get("release_notes", "")
|
||||
download_urls = data.get("download_urls", {})
|
||||
|
||||
#Get the download url for current platform
|
||||
download_url = self._get_download_url(download_urls)
|
||||
|
||||
if self._is_newer_version(remote_version):
|
||||
logging.info(f"Update available: {remote_version}")
|
||||
|
||||
#Store the available version
|
||||
self.settings.set_value(
|
||||
Settings.Key.UPDATE_AVAILABLE_VERSION,
|
||||
remote_version
|
||||
)
|
||||
|
||||
update_info = UpdateInfo(
|
||||
version=remote_version,
|
||||
release_notes=release_notes,
|
||||
download_urls=download_url
|
||||
)
|
||||
self.update_available.emit(update_info)
|
||||
|
||||
else:
|
||||
logging.info("No update available")
|
||||
self.settings.set_value(
|
||||
Settings.Key.UPDATE_AVAILABLE_VERSION,
|
||||
""
|
||||
)
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
error_msg = f"Failed to parse version info: {e}"
|
||||
logging.error(error_msg)
|
||||
|
||||
def _get_download_url(self, download_urls: dict) -> list:
|
||||
system = platform.system()
|
||||
machine = platform.machine().lower()
|
||||
|
||||
if system == "Windows":
|
||||
urls = download_urls.get("windows_x64", [])
|
||||
elif system == "Darwin":
|
||||
if machine in ("arm64", "aarch64"):
|
||||
urls = download_urls.get("macos_arm", [])
|
||||
else:
|
||||
urls = download_urls.get("macos_x86", [])
|
||||
else:
|
||||
urls = []
|
||||
|
||||
return urls if isinstance(urls, list) else [urls]
|
||||
|
||||
def _is_newer_version(self, remote_version: str) -> bool:
|
||||
"""Compare remote version with current version"""
|
||||
try:
|
||||
current_parts = [int(x) for x in VERSION.split(".")]
|
||||
remote_parts = [int(x) for x in remote_version.split(".")]
|
||||
|
||||
#pad with zeros if needed
|
||||
while len(current_parts) < len(remote_parts):
|
||||
current_parts.append(0)
|
||||
while len(remote_parts) < len(current_parts):
|
||||
remote_parts.append(0)
|
||||
|
||||
return remote_parts > current_parts
|
||||
|
||||
except ValueError:
|
||||
logging.error(f"Invalid version format: {VERSION} or {remote_version}")
|
||||
return False
|
||||
|
|
@ -1,10 +1,7 @@
|
|||
import subprocess
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from subprocess import run
|
||||
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
|
|
@ -13,8 +10,6 @@ HOP_LENGTH = 160
|
|||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
|
||||
app_env = os.environ.copy()
|
||||
app_env['PATH'] = os.pathsep.join([os.path.join(APP_BASE_DIR, "_internal")] + [app_env['PATH']])
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
|
|
@ -49,19 +44,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
|||
"-"
|
||||
]
|
||||
# fmt: on
|
||||
if sys.platform == "win32":
|
||||
si = subprocess.STARTUPINFO()
|
||||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
si.wShowWindow = subprocess.SW_HIDE
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
startupinfo=si,
|
||||
env=app_env,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
result = run(cmd, capture_output=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
logging.warning(f"FFMPEG audio load warning. Process return code was not zero: {result.returncode}")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
from platformdirs import user_log_dir
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import Qt, QUrl
|
||||
|
|
@ -81,9 +80,6 @@ class AboutDialog(QDialog):
|
|||
self.check_updates_button = QPushButton(_("Check for updates"), self)
|
||||
self.check_updates_button.clicked.connect(self.on_click_check_for_updates)
|
||||
|
||||
self.show_logs_button = QPushButton(_("Show logs"), self)
|
||||
self.show_logs_button.clicked.connect(self.on_click_show_logs)
|
||||
|
||||
button_box = QDialogButtonBox(
|
||||
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Close), self
|
||||
)
|
||||
|
|
@ -94,21 +90,15 @@ class AboutDialog(QDialog):
|
|||
layout.addWidget(buzz_label)
|
||||
layout.addWidget(version_label)
|
||||
layout.addWidget(self.check_updates_button)
|
||||
layout.addWidget(self.show_logs_button)
|
||||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.setMinimumWidth(350)
|
||||
|
||||
def on_click_check_for_updates(self):
|
||||
url = QUrl(self.GITHUB_API_LATEST_RELEASE_URL)
|
||||
self.network_access_manager.get(QNetworkRequest(url))
|
||||
self.check_updates_button.setDisabled(True)
|
||||
|
||||
def on_click_show_logs(self):
|
||||
log_dir = user_log_dir(appname="Buzz")
|
||||
QDesktopServices.openUrl(QUrl.fromLocalFile(log_dir))
|
||||
|
||||
def on_latest_release_reply(self, reply: QNetworkReply):
|
||||
if reply.error() == QNetworkReply.NetworkError.NoError:
|
||||
response = json.loads(reply.readAll().data())
|
||||
|
|
|
|||
|
|
@ -1,15 +1,9 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import locale
|
||||
import platform
|
||||
import darkdetect
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtGui import QFont
|
||||
from PyQt6.QtWidgets import QApplication, QStyleFactory
|
||||
from PyQt6.QtWidgets import QApplication
|
||||
from PyQt6.QtGui import QPalette, QColor
|
||||
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.db.dao.transcription_dao import TranscriptionDAO
|
||||
|
|
@ -32,30 +26,139 @@ class Application(QApplication):
|
|||
self.setApplicationVersion(VERSION)
|
||||
self.hide_main_window = False
|
||||
|
||||
if darkdetect.isDark():
|
||||
self.styleHints().setColorScheme(Qt.ColorScheme.Dark)
|
||||
self.setStyleSheet("QCheckBox::indicator:unchecked { border: 1px solid white; }")
|
||||
if sys.platform.startswith("win") and darkdetect.isDark():
|
||||
palette = QPalette()
|
||||
palette.setColor(QPalette.ColorRole.Window, QColor("#121212"))
|
||||
palette.setColor(QPalette.ColorRole.WindowText, QColor("#ffffff"))
|
||||
palette.setColor(QPalette.ColorRole.Base, QColor("#1e1e1e"))
|
||||
palette.setColor(QPalette.ColorRole.AlternateBase, QColor("#2e2e2e"))
|
||||
palette.setColor(QPalette.ColorRole.ToolTipBase, QColor("#ffffff"))
|
||||
palette.setColor(QPalette.ColorRole.ToolTipText, QColor("#000000"))
|
||||
palette.setColor(QPalette.ColorRole.Text, QColor("#ffffff"))
|
||||
palette.setColor(QPalette.ColorRole.Button, QColor("#1e1e1e"))
|
||||
palette.setColor(QPalette.ColorRole.ButtonText, QColor("#ffffff"))
|
||||
palette.setColor(QPalette.ColorRole.BrightText, QColor("#ff0000"))
|
||||
palette.setColor(QPalette.ColorRole.HighlightedText, QColor("#000000"))
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
self.setStyle(QStyleFactory.create("Fusion"))
|
||||
self.setPalette(palette)
|
||||
|
||||
# For Windows 11
|
||||
stylesheet = """
|
||||
QWidget {
|
||||
background-color: #121212;
|
||||
color: #ffffff;
|
||||
}
|
||||
QPushButton {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
}
|
||||
QPushButton:hover {
|
||||
background-color: #2e2e2e;
|
||||
}
|
||||
QHeaderView::section {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
font-weight: bold;
|
||||
}
|
||||
QToolBar {
|
||||
border: 1px solid #2e2e2e;
|
||||
}
|
||||
QTabBar::tab {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
}
|
||||
QTabBar::tab:selected {
|
||||
background-color: #2e2e2e;
|
||||
}
|
||||
QLineEdit, QTextEdit, QPlainTextEdit, QSpinBox, QDoubleSpinBox,
|
||||
QTabWidget::pane, QFormLayout, QHBoxLayout, QVBoxLayout, QTreeWidget,
|
||||
QTableView, QGroupBox {
|
||||
border: 1px solid #2e2e2e;
|
||||
}
|
||||
QLineEdit:focus, QTextEdit:focus, QPlainTextEdit:focus, QSpinBox:focus, QDoubleSpinBox:focus,
|
||||
QTabWidget::pane:focus, QFormLayout:focus, QHBoxLayout:focus, QVBoxLayout:focus, QTreeWidget:focus,
|
||||
QTableView:focus, QGroupBox:focus {
|
||||
border: 1px solid #4e4e4e;
|
||||
}
|
||||
QMenuBar {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
}
|
||||
QMenuBar::item {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
}
|
||||
QMenuBar::item:selected {
|
||||
background-color: #2e2e2e;
|
||||
}
|
||||
QMenu::item {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
}
|
||||
QMenu::item:selected {
|
||||
background-color: #2e2e2e;
|
||||
}
|
||||
QMenu::item:hover {
|
||||
background-color: #2e2e2e;
|
||||
}
|
||||
QToolButton {
|
||||
background-color: transparent;
|
||||
min-height: 30px;
|
||||
min-width: 30px;
|
||||
}
|
||||
QToolButton:hover {
|
||||
background-color: #2e2e2e;
|
||||
}
|
||||
QScrollBar:vertical {
|
||||
background-color: #1e1e1e;
|
||||
width: 16px;
|
||||
margin: 16px 0 16px 0;
|
||||
}
|
||||
QScrollBar::handle:vertical {
|
||||
background-color: #2e2e2e;
|
||||
min-height: 20px;
|
||||
}
|
||||
QScrollBar::add-line:vertical {
|
||||
background-color: #1e1e1e;
|
||||
height: 16px;
|
||||
subcontrol-position: bottom;
|
||||
subcontrol-origin: margin;
|
||||
}
|
||||
QScrollBar::sub-line:vertical {
|
||||
background-color: #1e1e1e;
|
||||
height: 16px;
|
||||
subcontrol-position: top;
|
||||
subcontrol-origin: margin;
|
||||
}
|
||||
QScrollBar:horizontal {
|
||||
background-color: #1e1e1e;
|
||||
height: 16px;
|
||||
margin: 0 16px 0 16px;
|
||||
}
|
||||
QScrollBar::handle:horizontal {
|
||||
background-color: #2e2e2e;
|
||||
min-width: 20px;
|
||||
}
|
||||
QScrollBar::add-line:horizontal {
|
||||
background-color: #1e1e1e;
|
||||
width: 16px;
|
||||
subcontrol-position: right;
|
||||
subcontrol-origin: margin;
|
||||
}
|
||||
QScrollBar::sub-line:horizontal {
|
||||
background-color: #1e1e1e;
|
||||
width: 16px;
|
||||
subcontrol-position: left;
|
||||
subcontrol-origin: margin;
|
||||
}
|
||||
QScrollBar::sub-page:horizontal, QScrollBar::add-page:horizontal,
|
||||
QScrollBar::sub-page:vertical, QScrollBar::add-page:vertical {
|
||||
background-color: #1e1e1e;
|
||||
}
|
||||
"""
|
||||
self.setStyleSheet(stylesheet)
|
||||
|
||||
self.settings = Settings()
|
||||
logging.debug(f"Settings filename: {self.settings.settings.fileName()}")
|
||||
|
||||
# Set BUZZ_FORCE_CPU environment variable if Force CPU setting is enabled
|
||||
force_cpu_enabled = self.settings.value(
|
||||
key=Settings.Key.FORCE_CPU, default_value=False
|
||||
)
|
||||
if force_cpu_enabled:
|
||||
os.environ["BUZZ_FORCE_CPU"] = "true"
|
||||
|
||||
# Set BUZZ_REDUCE_GPU_MEMORY environment variable if Reduce GPU RAM setting is enabled
|
||||
reduce_gpu_memory_enabled = self.settings.value(
|
||||
key=Settings.Key.REDUCE_GPU_MEMORY, default_value=False
|
||||
)
|
||||
if reduce_gpu_memory_enabled:
|
||||
os.environ["BUZZ_REDUCE_GPU_MEMORY"] = "true"
|
||||
|
||||
font_size = self.settings.value(
|
||||
key=Settings.Key.FONT_SIZE, default_value=self.font().pointSize()
|
||||
)
|
||||
|
|
@ -65,34 +168,13 @@ class Application(QApplication):
|
|||
else:
|
||||
self.setFont(QFont(self.font().family(), font_size))
|
||||
|
||||
self.db = setup_app_db()
|
||||
db = setup_app_db()
|
||||
transcription_service = TranscriptionService(
|
||||
TranscriptionDAO(self.db), TranscriptionSegmentDAO(self.db)
|
||||
TranscriptionDAO(db), TranscriptionSegmentDAO(db)
|
||||
)
|
||||
|
||||
self.window = MainWindow(transcription_service)
|
||||
|
||||
disable_telemetry = os.getenv("BUZZ_DISABLE_TELEMETRY", None)
|
||||
|
||||
if not disable_telemetry:
|
||||
posthog = Posthog(project_api_key='phc_NqZQUw8NcxfSXsbtk5eCFylmCQpp4FuNnd6ocPAzg2f',
|
||||
host='https://us.i.posthog.com')
|
||||
posthog.capture(distinct_id=self.settings.get_user_identifier(), event="app_launched", properties={
|
||||
"app": VERSION,
|
||||
"locale": locale.getlocale(),
|
||||
"system": platform.system(),
|
||||
"release": platform.release(),
|
||||
"machine": platform.machine(),
|
||||
"version": platform.version(),
|
||||
})
|
||||
|
||||
logging.debug(f"Launching Buzz: {VERSION}, "
|
||||
f"locale: {locale.getlocale()}, "
|
||||
f"system: {platform.system()}, "
|
||||
f"release: {platform.release()}, "
|
||||
f"machine: {platform.machine()}, "
|
||||
f"version: {platform.version()}, ")
|
||||
|
||||
def show_main_window(self):
|
||||
if not self.hide_main_window:
|
||||
self.window.show()
|
||||
|
|
@ -100,7 +182,3 @@ class Application(QApplication):
|
|||
def add_task(self, task: FileTranscriptionTask, quit_on_complete: bool = False):
|
||||
self.window.quit_on_complete = quit_on_complete
|
||||
self.window.add_task(task)
|
||||
|
||||
def close_database(self):
|
||||
from buzz.db.db import close_app_db
|
||||
close_app_db()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import Qt, QRect
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtGui import QColor, QPainter
|
||||
from PyQt6.QtWidgets import QWidget
|
||||
|
||||
from buzz.locale import _
|
||||
|
||||
|
||||
class AudioMeterWidget(QWidget):
|
||||
current_amplitude: float
|
||||
|
|
@ -22,19 +20,15 @@ class AudioMeterWidget(QWidget):
|
|||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
self.setMinimumWidth(10)
|
||||
self.setFixedHeight(56)
|
||||
self.setFixedHeight(16)
|
||||
|
||||
self.BARS_HEIGHT = 28
|
||||
# Extra padding to fix layout
|
||||
self.PADDING_TOP = 14
|
||||
self.PADDING_TOP = 3
|
||||
|
||||
self.current_amplitude = 0.0
|
||||
|
||||
self.average_amplitude = 0.0
|
||||
self.queue_size = 0
|
||||
|
||||
self.MINIMUM_AMPLITUDE = 0.00005 # minimum amplitude to show the first bar
|
||||
self.AMPLITUDE_SCALE_FACTOR = 10 # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars
|
||||
self.AMPLITUDE_SCALE_FACTOR = 15 # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars
|
||||
|
||||
if self.palette().window().color().black() > 127:
|
||||
self.BAR_INACTIVE_COLOR = QColor("#555")
|
||||
|
|
@ -64,39 +58,18 @@ class AudioMeterWidget(QWidget):
|
|||
center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)),
|
||||
rect.top() + self.PADDING_TOP,
|
||||
self.BAR_WIDTH,
|
||||
self.BARS_HEIGHT - self.PADDING_TOP,
|
||||
rect.height() - self.PADDING_TOP,
|
||||
)
|
||||
# draw to right
|
||||
painter.drawRect(
|
||||
center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))),
|
||||
rect.top() + self.PADDING_TOP,
|
||||
self.BAR_WIDTH,
|
||||
self.BARS_HEIGHT - self.PADDING_TOP,
|
||||
rect.height() - self.PADDING_TOP,
|
||||
)
|
||||
|
||||
text_rect = QRect(rect.left(), self.BARS_HEIGHT, rect.width(), rect.height() - self.BARS_HEIGHT)
|
||||
painter.setPen(self.BAR_ACTIVE_COLOR)
|
||||
average_volume_label = _("Average volume")
|
||||
queue_label = _("Queue")
|
||||
painter.drawText(text_rect, Qt.AlignmentFlag.AlignCenter,
|
||||
f"{average_volume_label}: {self.average_amplitude:.4f} {queue_label}: {self.queue_size}")
|
||||
|
||||
def reset_amplitude(self):
|
||||
self.current_amplitude = 0.0
|
||||
self.average_amplitude = 0.0
|
||||
self.queue_size = 0
|
||||
self.repaint()
|
||||
|
||||
def update_amplitude(self, amplitude: float):
|
||||
self.current_amplitude = max(
|
||||
amplitude, self.current_amplitude * self.SMOOTHING_FACTOR
|
||||
)
|
||||
self.update()
|
||||
|
||||
def update_average_amplitude(self, amplitude: float):
|
||||
self.average_amplitude = amplitude
|
||||
self.update()
|
||||
|
||||
def update_queue_size(self, size: int):
|
||||
self.queue_size = size
|
||||
self.update()
|
||||
self.repaint()
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
import logging
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import QTime, QUrl, Qt, pyqtSignal
|
||||
from PyQt6.QtMultimedia import QAudioOutput, QMediaPlayer, QMediaDevices
|
||||
from PyQt6.QtWidgets import QWidget, QSlider, QPushButton, QLabel, QHBoxLayout, QVBoxLayout
|
||||
from PyQt6.QtMultimedia import QAudioOutput, QMediaPlayer
|
||||
from PyQt6.QtWidgets import QWidget, QSlider, QPushButton, QLabel, QHBoxLayout
|
||||
|
||||
from buzz.widgets.icon import PlayIcon, PauseIcon
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber.file_transcriber import is_video_file
|
||||
|
||||
|
||||
class AudioPlayer(QWidget):
|
||||
|
|
@ -21,51 +18,17 @@ class AudioPlayer(QWidget):
|
|||
self.position_ms = 0
|
||||
self.duration_ms = 0
|
||||
self.invalid_media = None
|
||||
self.is_looping = False # Flag to prevent recursive position changes
|
||||
self.is_slider_dragging = False # Flag to track if use is dragging slider
|
||||
|
||||
# Initialize settings
|
||||
self.settings = Settings()
|
||||
|
||||
self.is_video = is_video_file(file_path)
|
||||
|
||||
self.audio_output = QAudioOutput()
|
||||
self.audio_output.setVolume(100)
|
||||
|
||||
# Log audio device info for debugging
|
||||
default_device = QMediaDevices.defaultAudioOutput()
|
||||
if default_device.isNull():
|
||||
logging.warning("No default audio output device found!")
|
||||
else:
|
||||
logging.info(f"Audio output device: {default_device.description()}")
|
||||
|
||||
audio_outputs = QMediaDevices.audioOutputs()
|
||||
logging.info(f"Available audio outputs: {[d.description() for d in audio_outputs]}")
|
||||
|
||||
self.media_player = QMediaPlayer()
|
||||
self.media_player.setSource(QUrl.fromLocalFile(file_path))
|
||||
self.media_player.setAudioOutput(self.audio_output)
|
||||
|
||||
if self.is_video:
|
||||
from PyQt6.QtMultimediaWidgets import QVideoWidget
|
||||
self.video_widget = QVideoWidget(self)
|
||||
self.media_player.setVideoOutput(self.video_widget)
|
||||
else:
|
||||
self.video_widget = None
|
||||
|
||||
# Speed control moved to transcription viewer - just set default rate
|
||||
saved_rate = self.settings.value(Settings.Key.AUDIO_PLAYBACK_RATE, 1.0, float)
|
||||
saved_rate = max(0.1, min(5.0, saved_rate)) # Ensure valid range
|
||||
self.media_player.setPlaybackRate(saved_rate)
|
||||
|
||||
self.scrubber = QSlider(Qt.Orientation.Horizontal)
|
||||
self.scrubber.setRange(0, 0)
|
||||
self.scrubber.sliderMoved.connect(self.on_slider_moved)
|
||||
self.scrubber.sliderPressed.connect(self.on_slider_pressed)
|
||||
self.scrubber.sliderReleased.connect(self.on_slider_released)
|
||||
|
||||
# Track if user is dragging the slider
|
||||
self.is_slider_dragging = False
|
||||
|
||||
self.play_icon = PlayIcon(self)
|
||||
self.pause_icon = PauseIcon(self)
|
||||
|
|
@ -73,39 +36,22 @@ class AudioPlayer(QWidget):
|
|||
self.play_button = QPushButton("")
|
||||
self.play_button.setIcon(self.play_icon)
|
||||
self.play_button.clicked.connect(self.toggle_play)
|
||||
self.play_button.setMaximumWidth(40) # Match other button widths
|
||||
self.play_button.setMinimumHeight(30) # Match other button heights
|
||||
|
||||
self.time_label = QLabel()
|
||||
self.time_label.setAlignment(Qt.AlignmentFlag.AlignRight)
|
||||
|
||||
# Create main layout - simplified without speed controls
|
||||
if self.is_video:
|
||||
#Vertical layout for video
|
||||
main_layout = QVBoxLayout()
|
||||
main_layout.addWidget(self.video_widget, stretch=1) # As video takes more space
|
||||
layout = QHBoxLayout()
|
||||
layout.addWidget(self.play_button, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
layout.addWidget(self.scrubber, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
layout.addWidget(self.time_label, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
|
||||
controls_layout = QHBoxLayout()
|
||||
controls_layout.addWidget(self.play_button, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
controls_layout.addWidget(self.scrubber, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
controls_layout.addWidget(self.time_label, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
|
||||
main_layout.addLayout(controls_layout)
|
||||
else:
|
||||
# Horizontal layout for audio only
|
||||
main_layout = QHBoxLayout()
|
||||
main_layout.addWidget(self.play_button, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
main_layout.addWidget(self.scrubber, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
main_layout.addWidget(self.time_label, alignment=Qt.AlignmentFlag.AlignVCenter)
|
||||
|
||||
self.setLayout(main_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
# Connect media player signals to the corresponding slots
|
||||
self.media_player.durationChanged.connect(self.on_duration_changed)
|
||||
self.media_player.positionChanged.connect(self.on_position_changed)
|
||||
self.media_player.playbackStateChanged.connect(self.on_playback_state_changed)
|
||||
self.media_player.mediaStatusChanged.connect(self.on_media_status_changed)
|
||||
self.media_player.errorOccurred.connect(self.on_error_occurred)
|
||||
|
||||
self.on_duration_changed(self.media_player.duration())
|
||||
|
||||
|
|
@ -115,27 +61,17 @@ class AudioPlayer(QWidget):
|
|||
self.update_time_label()
|
||||
|
||||
def on_position_changed(self, position_ms: int):
|
||||
# Don't update slider if user is currently dragging it
|
||||
if not self.is_slider_dragging:
|
||||
self.scrubber.blockSignals(True)
|
||||
self.scrubber.setValue(position_ms)
|
||||
self.scrubber.blockSignals(False)
|
||||
|
||||
self.scrubber.setValue(position_ms)
|
||||
self.position_ms = position_ms
|
||||
self.position_ms_changed.emit(self.position_ms)
|
||||
self.update_time_label()
|
||||
|
||||
# If a range has been selected as we've reached the end of the range,
|
||||
# loop back to the start of the range
|
||||
if self.range_ms is not None and not self.is_looping:
|
||||
if self.range_ms is not None:
|
||||
start_range_ms, end_range_ms = self.range_ms
|
||||
# Check if we're at or past the end of the range (with small buffer for precision)
|
||||
if position_ms >= (end_range_ms - 50): # Within 50ms of end
|
||||
logging.debug(f"🔄 LOOP: Reached end {end_range_ms}ms, jumping to start {start_range_ms}ms")
|
||||
self.is_looping = True # Set flag to prevent recursion
|
||||
if position_ms > end_range_ms:
|
||||
self.set_position(start_range_ms)
|
||||
# Reset flag immediately after setting position
|
||||
self.is_looping = False
|
||||
|
||||
def on_playback_state_changed(self, state: QMediaPlayer.PlaybackState):
|
||||
if state == QMediaPlayer.PlaybackState.PlayingState:
|
||||
|
|
@ -144,16 +80,12 @@ class AudioPlayer(QWidget):
|
|||
self.play_button.setIcon(self.play_icon)
|
||||
|
||||
def on_media_status_changed(self, status: QMediaPlayer.MediaStatus):
|
||||
logging.debug(f"Media status changed: {status}")
|
||||
match status:
|
||||
case QMediaPlayer.MediaStatus.InvalidMedia:
|
||||
self.set_invalid_media(True)
|
||||
case QMediaPlayer.MediaStatus.LoadedMedia:
|
||||
self.set_invalid_media(False)
|
||||
|
||||
def on_error_occurred(self, error: QMediaPlayer.Error, error_string: str):
|
||||
logging.error(f"Media player error: {error} - {error_string}")
|
||||
|
||||
def set_invalid_media(self, invalid_media: bool):
|
||||
self.invalid_media = invalid_media
|
||||
if self.invalid_media:
|
||||
|
|
@ -161,10 +93,6 @@ class AudioPlayer(QWidget):
|
|||
self.scrubber.setRange(0, 1)
|
||||
self.scrubber.setDisabled(True)
|
||||
self.time_label.setDisabled(True)
|
||||
else:
|
||||
self.play_button.setEnabled(True)
|
||||
self.scrubber.setEnabled(True)
|
||||
self.time_label.setEnabled(True)
|
||||
|
||||
def toggle_play(self):
|
||||
if self.media_player.playbackState() == QMediaPlayer.PlaybackState.PlayingState:
|
||||
|
|
@ -173,41 +101,13 @@ class AudioPlayer(QWidget):
|
|||
self.media_player.play()
|
||||
|
||||
def set_range(self, range_ms: Tuple[int, int]):
|
||||
"""Set a loop range. Only jump to start if current position is outside the range."""
|
||||
self.range_ms = range_ms
|
||||
start_range_ms, end_range_ms = range_ms
|
||||
|
||||
# Only jump to start if current position is outside the range
|
||||
if self.position_ms < start_range_ms or self.position_ms > end_range_ms:
|
||||
logging.debug(f"🔄 LOOP: Position {self.position_ms}ms outside range, jumping to {start_range_ms}ms")
|
||||
self.set_position(start_range_ms)
|
||||
|
||||
def clear_range(self):
|
||||
"""Clear the current loop range"""
|
||||
self.range_ms = None
|
||||
|
||||
def _reset_looping_flag(self):
|
||||
"""Reset the looping flag"""
|
||||
self.is_looping = False
|
||||
self.set_position(range_ms[0])
|
||||
|
||||
def on_slider_moved(self, position_ms: int):
|
||||
self.set_position(position_ms)
|
||||
# Only clear range if scrubbed significantly outside the current range
|
||||
if self.range_ms is not None:
|
||||
start_range_ms, end_range_ms = self.range_ms
|
||||
# Clear range if scrubbed more than 2 seconds outside the range
|
||||
if position_ms < (start_range_ms - 2000) or position_ms > (end_range_ms + 2000):
|
||||
self.range_ms = None
|
||||
|
||||
def on_slider_pressed(self):
|
||||
"""Called when the user starts dragging the slider"""
|
||||
self.is_slider_dragging = True
|
||||
|
||||
def on_slider_released(self):
|
||||
"""Called when user releases the slider"""
|
||||
self.is_slider_dragging = False
|
||||
# Update the position where user released
|
||||
self.set_position(self.scrubber.value())
|
||||
# Reset range if slider is scrubbed manually
|
||||
self.range_ms = None
|
||||
|
||||
def set_position(self, position_ms: int):
|
||||
self.media_player.setPosition(position_ms)
|
||||
|
|
|
|||
|
|
@ -82,10 +82,6 @@ class ResizeIcon(Icon):
|
|||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_path("assets/resize_black.svg"), parent)
|
||||
|
||||
class SpeakerIdentificationIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_path("assets/speaker-identification.svg"), parent)
|
||||
|
||||
class VisibilityIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(
|
||||
|
|
@ -93,32 +89,6 @@ class VisibilityIcon(Icon):
|
|||
)
|
||||
|
||||
|
||||
class ScrollToCurrentIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(
|
||||
get_path("assets/visibility_FILL0_wght700_GRAD0_opsz48.svg"), parent
|
||||
)
|
||||
|
||||
class NewWindowIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_path("assets/icons/new-window.svg"), parent)
|
||||
|
||||
|
||||
class FullscreenIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_path("assets/icons/fullscreen.svg"), parent)
|
||||
|
||||
|
||||
class ColorBackgroundIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_path("assets/icons/color-background.svg"), parent)
|
||||
|
||||
|
||||
class TextColorIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_path("assets/icons/gui-text-color.svg"), parent)
|
||||
|
||||
|
||||
BUZZ_ICON_PATH = get_path("assets/buzz.ico")
|
||||
BUZZ_LARGE_ICON_PATH = get_path("assets/buzz-icon-1024.png")
|
||||
|
||||
|
|
@ -126,7 +96,5 @@ INFO_ICON_PATH = get_path("assets/info-circle.svg")
|
|||
RECORD_ICON_PATH = get_path("assets/mic_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
EXPAND_ICON_PATH = get_path("assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
ADD_ICON_PATH = get_path("assets/add_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
URL_ICON_PATH = get_path("assets/url.svg")
|
||||
TRASH_ICON_PATH = get_path("assets/delete_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
CANCEL_ICON_PATH = get_path("assets/cancel_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
UPDATE_ICON_PATH = get_path("assets/update_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
from PyQt6.QtGui import QIcon, QPixmap, QPainter, QPalette
|
||||
from PyQt6.QtCore import QSize
|
||||
from PyQt6.QtSvg import QSvgRenderer
|
||||
import os
|
||||
from buzz.assets import APP_BASE_DIR
|
||||
|
||||
class PresentationIcon:
|
||||
"Icons for presentation window controls"
|
||||
def __init__(self, parent, svg_path: str, color: str = None):
|
||||
self.parent = parent
|
||||
self.svg_path = svg_path
|
||||
self.color = color or self.get_default_color()
|
||||
|
||||
|
||||
def get_default_color(self) -> str:
|
||||
"""Get default icon color based on theme"""
|
||||
palette = self.parent.palette()
|
||||
is_dark = palette.window().color().black() > 127
|
||||
|
||||
return "#EEE" if is_dark else "#555"
|
||||
|
||||
def get_icon(self) -> QIcon:
|
||||
"""Load SVG icon and return as QIcon"""
|
||||
#Load from asset first
|
||||
full_path = os.path.join(APP_BASE_DIR, "assets", "icons", os.path.basename(self.svg_path))
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
pixmap = QPixmap(24, 24)
|
||||
pixmap.fill(self.color)
|
||||
|
||||
return QIcon(pixmap)
|
||||
|
||||
#Load SVG
|
||||
renderer = QSvgRenderer(full_path)
|
||||
pixmap = QPixmap(24, 24)
|
||||
pixmap.fill(Qt.GlobalColor.transparent)
|
||||
painter = QPainter(pixmap)
|
||||
renderer.render(painter)
|
||||
painter.end()
|
||||
|
||||
return QIcon(pixmap)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
import os
|
||||
import logging
|
||||
import sounddevice
|
||||
import keyring
|
||||
from typing import Tuple, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (
|
||||
Qt,
|
||||
QThread,
|
||||
QModelIndex,
|
||||
pyqtSignal
|
||||
)
|
||||
|
||||
from PyQt6.QtGui import QIcon
|
||||
from PyQt6.QtWidgets import (
|
||||
QApplication,
|
||||
|
|
@ -24,8 +23,6 @@ from buzz.db.service.transcription_service import TranscriptionService
|
|||
from buzz.file_transcriber_queue_worker import FileTranscriberQueueWorker
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import APP_NAME, Settings
|
||||
from buzz.update_checker import UpdateChecker, UpdateInfo
|
||||
from buzz.widgets.update_dialog import UpdateDialog
|
||||
from buzz.settings.shortcuts import Shortcuts
|
||||
from buzz.store.keyring_store import set_password, Key
|
||||
from buzz.transcriber.transcriber import (
|
||||
|
|
@ -39,11 +36,11 @@ from buzz.widgets.icon import BUZZ_ICON_PATH
|
|||
from buzz.widgets.import_url_dialog import ImportURLDialog
|
||||
from buzz.widgets.main_window_toolbar import MainWindowToolbar
|
||||
from buzz.widgets.menu_bar import MenuBar
|
||||
from buzz.widgets.snap_notice import SnapNotice
|
||||
from buzz.widgets.preferences_dialog.models.preferences import Preferences
|
||||
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
from buzz.widgets.transcription_task_folder_watcher import (
|
||||
TranscriptionTaskFolderWatcher,
|
||||
SUPPORTED_EXTENSIONS,
|
||||
)
|
||||
from buzz.widgets.transcription_tasks_table_widget import (
|
||||
TranscriptionTasksTableWidget,
|
||||
|
|
@ -55,13 +52,14 @@ from buzz.widgets.transcription_viewer.transcription_viewer_widget import (
|
|||
|
||||
class MainWindow(QMainWindow):
|
||||
table_widget: TranscriptionTasksTableWidget
|
||||
transcriptions_updated = pyqtSignal(UUID)
|
||||
|
||||
def __init__(self, transcription_service: TranscriptionService):
|
||||
super().__init__(flags=Qt.WindowType.Window)
|
||||
|
||||
self.setWindowTitle(APP_NAME)
|
||||
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
|
||||
self.setBaseSize(1240, 600)
|
||||
self.resize(1240, 600)
|
||||
|
||||
self.setAcceptDrops(True)
|
||||
|
||||
|
|
@ -72,16 +70,10 @@ class MainWindow(QMainWindow):
|
|||
self.quit_on_complete = False
|
||||
self.transcription_service = transcription_service
|
||||
|
||||
#update checker
|
||||
self._update_info: Optional[UpdateInfo] = None
|
||||
|
||||
self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self)
|
||||
self.toolbar.new_transcription_action_triggered.connect(
|
||||
self.on_new_transcription_action_triggered
|
||||
)
|
||||
self.toolbar.new_url_transcription_action_triggered.connect(
|
||||
self.on_new_url_transcription_action_triggered
|
||||
)
|
||||
self.toolbar.open_transcript_action_triggered.connect(
|
||||
self.open_transcript_viewer
|
||||
)
|
||||
|
|
@ -92,7 +84,6 @@ class MainWindow(QMainWindow):
|
|||
self.on_stop_transcription_action_triggered
|
||||
)
|
||||
self.addToolBar(self.toolbar)
|
||||
self.toolbar.update_action_triggered.connect(self.on_update_action_triggered)
|
||||
self.setUnifiedTitleAndToolBarOnMac(True)
|
||||
|
||||
self.preferences = self.load_preferences(settings=self.settings)
|
||||
|
|
@ -107,9 +98,6 @@ class MainWindow(QMainWindow):
|
|||
self.menu_bar.import_url_action_triggered.connect(
|
||||
self.on_new_url_transcription_action_triggered
|
||||
)
|
||||
self.menu_bar.import_folder_action_triggered.connect(
|
||||
self.on_import_folder_action_triggered
|
||||
)
|
||||
self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed)
|
||||
self.menu_bar.openai_api_key_changed.connect(
|
||||
self.on_openai_access_token_changed
|
||||
|
|
@ -118,16 +106,11 @@ class MainWindow(QMainWindow):
|
|||
self.setMenuBar(self.menu_bar)
|
||||
|
||||
self.table_widget = TranscriptionTasksTableWidget(self)
|
||||
self.table_widget.transcription_service = self.transcription_service
|
||||
self.table_widget.doubleClicked.connect(self.on_table_double_clicked)
|
||||
self.table_widget.return_clicked.connect(self.open_transcript_viewer)
|
||||
self.table_widget.delete_requested.connect(self.on_clear_history_action_triggered)
|
||||
self.table_widget.selectionModel().selectionChanged.connect(
|
||||
self.on_table_selection_changed
|
||||
)
|
||||
self.transcriptions_updated.connect(
|
||||
self.on_transcriptions_updated
|
||||
)
|
||||
|
||||
self.setCentralWidget(self.table_widget)
|
||||
|
||||
|
|
@ -162,8 +145,23 @@ class MainWindow(QMainWindow):
|
|||
|
||||
self.transcription_viewer_widget = None
|
||||
|
||||
#Initialize and run update checker
|
||||
self._init_update_checker()
|
||||
if os.environ.get('SNAP_NAME', '') == 'buzz':
|
||||
logging.debug("Running in a snap environment")
|
||||
self.check_linux_permissions()
|
||||
|
||||
def check_linux_permissions(self):
|
||||
devices = sounddevice.query_devices()
|
||||
input_devices = [device for device in devices if device['max_input_channels'] > 0]
|
||||
|
||||
if len(input_devices) == 0:
|
||||
snap_notice = SnapNotice(self)
|
||||
snap_notice.show()
|
||||
|
||||
try:
|
||||
_ = keyring.get_password(APP_NAME, username="random")
|
||||
except Exception:
|
||||
snap_notice = SnapNotice(self)
|
||||
snap_notice.show()
|
||||
|
||||
def on_preferences_changed(self, preferences: Preferences):
|
||||
self.preferences = preferences
|
||||
|
|
@ -268,20 +266,6 @@ class MainWindow(QMainWindow):
|
|||
if url is not None:
|
||||
self.open_file_transcriber_widget(url=url)
|
||||
|
||||
def on_import_folder_action_triggered(self):
|
||||
folder = QFileDialog.getExistingDirectory(self, _("Select folder"))
|
||||
if not folder:
|
||||
return
|
||||
file_paths = []
|
||||
for dirpath, _dirs, filenames in os.walk(folder):
|
||||
for filename in filenames:
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext in SUPPORTED_EXTENSIONS:
|
||||
file_paths.append(os.path.join(dirpath, filename))
|
||||
if not file_paths:
|
||||
return
|
||||
self.open_file_transcriber_widget(file_paths)
|
||||
|
||||
def open_file_transcriber_widget(
|
||||
self, file_paths: Optional[List[str]] = None, url: Optional[str] = None
|
||||
):
|
||||
|
|
@ -384,7 +368,6 @@ class MainWindow(QMainWindow):
|
|||
shortcuts=self.shortcuts,
|
||||
parent=self,
|
||||
flags=Qt.WindowType.Window,
|
||||
transcriptions_updated_signal=self.transcriptions_updated,
|
||||
)
|
||||
self.transcription_viewer_widget.show()
|
||||
|
||||
|
|
@ -393,9 +376,6 @@ class MainWindow(QMainWindow):
|
|||
self.table_widget.refresh_all()
|
||||
self.transcriber_worker.add_task(task)
|
||||
|
||||
def on_transcriptions_updated(self):
|
||||
self.table_widget.refresh_all()
|
||||
|
||||
def on_task_started(self, task: FileTranscriptionTask):
|
||||
self.transcription_service.update_transcription_as_started(task.uid)
|
||||
self.table_widget.refresh_row(task.uid)
|
||||
|
|
@ -411,14 +391,6 @@ class MainWindow(QMainWindow):
|
|||
pass
|
||||
|
||||
def on_task_completed(self, task: FileTranscriptionTask, segments: List[Segment]):
|
||||
# Update file path in database only for URL imports where file is downloaded
|
||||
if task.source == FileTranscriptionTask.Source.URL_IMPORT and task.file_path:
|
||||
logging.debug(f"Updating transcription file path: {task.file_path}")
|
||||
# Use the file basename (video title) as the display name
|
||||
basename = os.path.basename(task.file_path)
|
||||
name = os.path.splitext(basename)[0] # Remove .wav extension
|
||||
self.transcription_service.update_transcription_file_and_name(task.uid, task.file_path, name)
|
||||
|
||||
self.transcription_service.update_transcription_as_completed(task.uid, segments)
|
||||
self.table_widget.refresh_row(task.uid)
|
||||
|
||||
|
|
@ -444,48 +416,14 @@ class MainWindow(QMainWindow):
|
|||
|
||||
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
|
||||
self.save_geometry()
|
||||
self.settings.settings.sync()
|
||||
|
||||
if self.folder_watcher:
|
||||
try:
|
||||
self.folder_watcher.task_found.disconnect()
|
||||
if len(self.folder_watcher.directories()) > 0:
|
||||
self.folder_watcher.removePaths(self.folder_watcher.directories())
|
||||
except Exception as e:
|
||||
logging.warning(f"Error cleaning up folder watcher: {e}")
|
||||
|
||||
try:
|
||||
self.transcriber_worker.task_started.disconnect()
|
||||
self.transcriber_worker.task_progress.disconnect()
|
||||
self.transcriber_worker.task_download_progress.disconnect()
|
||||
self.transcriber_worker.task_error.disconnect()
|
||||
self.transcriber_worker.task_completed.disconnect()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error disconnecting signals: {e}")
|
||||
|
||||
self.transcriber_worker.stop()
|
||||
self.transcriber_thread.quit()
|
||||
|
||||
if self.transcriber_thread.isRunning():
|
||||
if not self.transcriber_thread.wait(10000):
|
||||
logging.warning("Transcriber thread did not finish within 10s timeout, terminating")
|
||||
self.transcriber_thread.terminate()
|
||||
if not self.transcriber_thread.wait(2000):
|
||||
logging.error("Transcriber thread could not be terminated")
|
||||
self.transcriber_thread.wait()
|
||||
|
||||
if self.transcription_viewer_widget is not None:
|
||||
self.transcription_viewer_widget.close()
|
||||
|
||||
try:
|
||||
from buzz.widgets.application import Application
|
||||
app = Application.instance()
|
||||
if app and hasattr(app, 'close_database'):
|
||||
app.close_database()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error closing database: {e}")
|
||||
|
||||
logging.debug("MainWindow closeEvent completed")
|
||||
|
||||
super().closeEvent(event)
|
||||
|
||||
def save_geometry(self):
|
||||
|
|
@ -498,31 +436,4 @@ class MainWindow(QMainWindow):
|
|||
geometry = self.settings.settings.value("geometry")
|
||||
if geometry is not None:
|
||||
self.restoreGeometry(geometry)
|
||||
else:
|
||||
self.setBaseSize(1240, 600)
|
||||
self.resize(1240, 600)
|
||||
self.settings.end_group()
|
||||
|
||||
def _init_update_checker(self):
|
||||
"""Initializes and runs the update checker."""
|
||||
self.update_checker = UpdateChecker(settings=self.settings, parent=self)
|
||||
self.update_checker.update_available.connect(self._on_update_available)
|
||||
|
||||
# Check for updates on startup
|
||||
self.update_checker.check_for_updates()
|
||||
|
||||
def _on_update_available(self, update_info: UpdateInfo):
|
||||
"""Called when an update is available."""
|
||||
self._update_info = update_info
|
||||
self.toolbar.set_update_available(True)
|
||||
|
||||
def on_update_action_triggered(self):
|
||||
"""Called when user clicks the update action in toolbar."""
|
||||
if self._update_info is None:
|
||||
return
|
||||
|
||||
dialog = UpdateDialog(
|
||||
update_info=self._update_info,
|
||||
parent=self
|
||||
)
|
||||
dialog.exec()
|
||||
|
|
@ -12,11 +12,9 @@ from buzz.widgets.icon import Icon
|
|||
from buzz.widgets.icon import (
|
||||
RECORD_ICON_PATH,
|
||||
ADD_ICON_PATH,
|
||||
URL_ICON_PATH,
|
||||
EXPAND_ICON_PATH,
|
||||
CANCEL_ICON_PATH,
|
||||
TRASH_ICON_PATH,
|
||||
UPDATE_ICON_PATH,
|
||||
)
|
||||
from buzz.widgets.recording_transcriber_widget import RecordingTranscriberWidget
|
||||
from buzz.widgets.toolbar import ToolBar
|
||||
|
|
@ -24,10 +22,8 @@ from buzz.widgets.toolbar import ToolBar
|
|||
|
||||
class MainWindowToolbar(ToolBar):
|
||||
new_transcription_action_triggered: pyqtSignal
|
||||
new_url_transcription_action_triggered: pyqtSignal
|
||||
open_transcript_action_triggered: pyqtSignal
|
||||
clear_history_action_triggered: pyqtSignal
|
||||
update_action_triggered: pyqtSignal
|
||||
ICON_LIGHT_THEME_BACKGROUND = "#555"
|
||||
ICON_DARK_THEME_BACKGROUND = "#AAA"
|
||||
|
||||
|
|
@ -39,22 +35,13 @@ class MainWindowToolbar(ToolBar):
|
|||
self.record_action = Action(Icon(RECORD_ICON_PATH, self), _("Record"), self)
|
||||
self.record_action.triggered.connect(self.on_record_action_triggered)
|
||||
|
||||
# Note: Changes to "New File Transcription" need to be reflected
|
||||
# also in tests/widgets/main_window_test.py
|
||||
self.new_transcription_action = Action(
|
||||
Icon(ADD_ICON_PATH, self), _("New File Transcription"), self
|
||||
Icon(ADD_ICON_PATH, self), _("New Transcription"), self
|
||||
)
|
||||
self.new_transcription_action_triggered = (
|
||||
self.new_transcription_action.triggered
|
||||
)
|
||||
|
||||
self.new_url_transcription_action = Action(
|
||||
Icon(URL_ICON_PATH, self), _("New URL Transcription"), self
|
||||
)
|
||||
self.new_url_transcription_action_triggered = (
|
||||
self.new_url_transcription_action.triggered
|
||||
)
|
||||
|
||||
self.open_transcript_action = Action(
|
||||
Icon(EXPAND_ICON_PATH, self), _("Open Transcript"), self
|
||||
)
|
||||
|
|
@ -72,13 +59,6 @@ class MainWindowToolbar(ToolBar):
|
|||
self.clear_history_action = Action(
|
||||
Icon(TRASH_ICON_PATH, self), _("Clear History"), self
|
||||
)
|
||||
|
||||
self.update_action = Action(
|
||||
Icon(UPDATE_ICON_PATH, self), _("Update Available"), self
|
||||
)
|
||||
self.update_action_triggered = self.update_action.triggered
|
||||
self.update_action.setVisible(False)
|
||||
|
||||
self.clear_history_action_triggered = self.clear_history_action.triggered
|
||||
self.clear_history_action.setDisabled(True)
|
||||
|
||||
|
|
@ -89,16 +69,11 @@ class MainWindowToolbar(ToolBar):
|
|||
self.addActions(
|
||||
[
|
||||
self.new_transcription_action,
|
||||
self.new_url_transcription_action,
|
||||
self.open_transcript_action,
|
||||
self.stop_transcription_action,
|
||||
self.clear_history_action,
|
||||
]
|
||||
)
|
||||
|
||||
self.addSeparator()
|
||||
self.addAction(self.update_action)
|
||||
|
||||
self.setMovable(False)
|
||||
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
|
||||
|
||||
|
|
@ -106,6 +81,9 @@ class MainWindowToolbar(ToolBar):
|
|||
self.record_action.setShortcut(
|
||||
QKeySequence.fromString(self.shortcuts.get(Shortcut.OPEN_RECORD_WINDOW))
|
||||
)
|
||||
self.new_transcription_action.setShortcut(
|
||||
QKeySequence.fromString(self.shortcuts.get(Shortcut.OPEN_IMPORT_WINDOW))
|
||||
)
|
||||
self.stop_transcription_action.setShortcut(
|
||||
QKeySequence.fromString(self.shortcuts.get(Shortcut.STOP_TRANSCRIPTION))
|
||||
)
|
||||
|
|
@ -127,7 +105,3 @@ class MainWindowToolbar(ToolBar):
|
|||
|
||||
def set_clear_history_action_enabled(self, enabled: bool):
|
||||
self.clear_history_action.setEnabled(enabled)
|
||||
|
||||
def set_update_available(self, available: bool):
|
||||
"""Shows or hides the update action in the toolbar."""
|
||||
self.update_action.setVisible(available)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import platform
|
||||
import webbrowser
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -20,7 +19,6 @@ from buzz.widgets.preferences_dialog.preferences_dialog import (
|
|||
class MenuBar(QMenuBar):
|
||||
import_action_triggered = pyqtSignal()
|
||||
import_url_action_triggered = pyqtSignal()
|
||||
import_folder_action_triggered = pyqtSignal()
|
||||
shortcuts_changed = pyqtSignal()
|
||||
openai_api_key_changed = pyqtSignal(str)
|
||||
preferences_changed = pyqtSignal(Preferences)
|
||||
|
|
@ -43,17 +41,12 @@ class MenuBar(QMenuBar):
|
|||
self.import_url_action = QAction(_("Import URL..."), self)
|
||||
self.import_url_action.triggered.connect(self.import_url_action_triggered)
|
||||
|
||||
self.import_folder_action = QAction(_("Import Folder..."), self)
|
||||
self.import_folder_action.triggered.connect(self.import_folder_action_triggered)
|
||||
|
||||
about_label = _("About")
|
||||
about_action = QAction(f'{about_label} {APP_NAME}', self)
|
||||
about_action.triggered.connect(self.on_about_action_triggered)
|
||||
about_action.setMenuRole(QAction.MenuRole.AboutRole)
|
||||
|
||||
self.preferences_action = QAction(_("Preferences..."), self)
|
||||
self.preferences_action.triggered.connect(self.on_preferences_action_triggered)
|
||||
self.preferences_action.setMenuRole(QAction.MenuRole.PreferencesRole)
|
||||
|
||||
help_label = _("Help")
|
||||
help_action = QAction(f'{help_label}', self)
|
||||
|
|
@ -64,10 +57,8 @@ class MenuBar(QMenuBar):
|
|||
file_menu = self.addMenu(_("File"))
|
||||
file_menu.addAction(self.import_action)
|
||||
file_menu.addAction(self.import_url_action)
|
||||
file_menu.addAction(self.import_folder_action)
|
||||
|
||||
help_menu_title = _("Help") + ("\u200B" if platform.system() == "Darwin" else "")
|
||||
help_menu = self.addMenu(help_menu_title)
|
||||
help_menu = self.addMenu(_("Help"))
|
||||
help_menu.addAction(about_action)
|
||||
help_menu.addAction(help_action)
|
||||
help_menu.addAction(self.preferences_action)
|
||||
|
|
|
|||
|
|
@ -18,9 +18,8 @@ class ModelDownloadProgressDialog(QProgressDialog):
|
|||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.setMinimumWidth(350)
|
||||
self.cancelable = (
|
||||
model_type == ModelType.WHISPER
|
||||
model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
|
||||
)
|
||||
self.start_time = datetime.now()
|
||||
self.setRange(0, 100)
|
||||
|
|
|
|||
|
|
@ -44,16 +44,11 @@ class FolderWatchPreferencesWidget(QWidget):
|
|||
checkbox.setObjectName("EnableFolderWatchCheckbox")
|
||||
checkbox.stateChanged.connect(self.on_enable_changed)
|
||||
|
||||
delete_checkbox = QCheckBox(_("Delete processed files"))
|
||||
delete_checkbox.setChecked(config.delete_processed_files)
|
||||
delete_checkbox.setObjectName("DeleteProcessedFilesCheckbox")
|
||||
delete_checkbox.stateChanged.connect(self.on_delete_processed_files_changed)
|
||||
input_folder_browse_button = QPushButton(_("Browse"))
|
||||
input_folder_browse_button.clicked.connect(self.on_click_browse_input_folder)
|
||||
|
||||
self.input_folder_browse_button = QPushButton(_("Browse"))
|
||||
self.input_folder_browse_button.clicked.connect(self.on_click_browse_input_folder)
|
||||
|
||||
self.output_folder_browse_button = QPushButton(_("Browse"))
|
||||
self.output_folder_browse_button.clicked.connect(self.on_click_browse_output_folder)
|
||||
output_folder_browse_button = QPushButton(_("Browse"))
|
||||
output_folder_browse_button.clicked.connect(self.on_click_browse_output_folder)
|
||||
|
||||
input_folder_row = QHBoxLayout()
|
||||
self.input_folder_line_edit = LineEdit(config.input_directory, self)
|
||||
|
|
@ -62,7 +57,7 @@ class FolderWatchPreferencesWidget(QWidget):
|
|||
self.input_folder_line_edit.setObjectName("InputFolderLineEdit")
|
||||
|
||||
input_folder_row.addWidget(self.input_folder_line_edit)
|
||||
input_folder_row.addWidget(self.input_folder_browse_button)
|
||||
input_folder_row.addWidget(input_folder_browse_button)
|
||||
|
||||
output_folder_row = QHBoxLayout()
|
||||
self.output_folder_line_edit = LineEdit(config.output_directory, self)
|
||||
|
|
@ -71,7 +66,7 @@ class FolderWatchPreferencesWidget(QWidget):
|
|||
self.output_folder_line_edit.setObjectName("OutputFolderLineEdit")
|
||||
|
||||
output_folder_row.addWidget(self.output_folder_line_edit)
|
||||
output_folder_row.addWidget(self.output_folder_browse_button)
|
||||
output_folder_row.addWidget(output_folder_browse_button)
|
||||
|
||||
openai_access_token = get_password(Key.OPENAI_API_KEY)
|
||||
(
|
||||
|
|
@ -82,17 +77,15 @@ class FolderWatchPreferencesWidget(QWidget):
|
|||
file_paths=[],
|
||||
)
|
||||
|
||||
self.transcription_form_widget = FileTranscriptionFormWidget(
|
||||
transcription_form_widget = FileTranscriptionFormWidget(
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
parent=self,
|
||||
)
|
||||
self.transcription_form_widget.transcription_options_changed.connect(
|
||||
transcription_form_widget.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed
|
||||
)
|
||||
|
||||
self.delete_checkbox = delete_checkbox
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
folders_form_layout = QFormLayout()
|
||||
|
|
@ -100,17 +93,14 @@ class FolderWatchPreferencesWidget(QWidget):
|
|||
folders_form_layout.addRow("", checkbox)
|
||||
folders_form_layout.addRow(_("Input folder"), input_folder_row)
|
||||
folders_form_layout.addRow(_("Output folder"), output_folder_row)
|
||||
folders_form_layout.addRow("", delete_checkbox)
|
||||
folders_form_layout.addWidget(self.transcription_form_widget)
|
||||
folders_form_layout.addWidget(transcription_form_widget)
|
||||
|
||||
layout.addLayout(folders_form_layout)
|
||||
layout.addWidget(self.transcription_form_widget)
|
||||
layout.addWidget(transcription_form_widget)
|
||||
layout.addStretch()
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
self._set_settings_enabled(config.enabled)
|
||||
|
||||
def on_click_browse_input_folder(self):
|
||||
folder = QFileDialog.getExistingDirectory(self, _("Select Input Folder"))
|
||||
self.input_folder_line_edit.setText(folder)
|
||||
|
|
@ -129,22 +119,8 @@ class FolderWatchPreferencesWidget(QWidget):
|
|||
self.config.output_directory = folder
|
||||
self.config_changed.emit(self.config)
|
||||
|
||||
def _set_settings_enabled(self, enabled: bool):
|
||||
self.input_folder_line_edit.setEnabled(enabled)
|
||||
self.input_folder_browse_button.setEnabled(enabled)
|
||||
self.output_folder_line_edit.setEnabled(enabled)
|
||||
self.output_folder_browse_button.setEnabled(enabled)
|
||||
self.delete_checkbox.setEnabled(enabled)
|
||||
self.transcription_form_widget.setEnabled(enabled)
|
||||
|
||||
def on_enable_changed(self, state: int):
|
||||
enabled = state == 2
|
||||
self.config.enabled = enabled
|
||||
self._set_settings_enabled(enabled)
|
||||
self.config_changed.emit(self.config)
|
||||
|
||||
def on_delete_processed_files_changed(self, state: int):
|
||||
self.config.delete_processed_files = state == 2
|
||||
self.config.enabled = state == 2
|
||||
self.config_changed.emit(self.config)
|
||||
|
||||
def on_transcription_options_changed(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import requests
|
|||
from typing import Optional
|
||||
from platformdirs import user_documents_dir
|
||||
|
||||
from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool, QLocale
|
||||
from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QFormLayout,
|
||||
|
|
@ -15,10 +15,7 @@ from PyQt6.QtWidgets import (
|
|||
QFileDialog,
|
||||
QSpinBox,
|
||||
QComboBox,
|
||||
QLabel,
|
||||
QSizePolicy,
|
||||
)
|
||||
from PyQt6.QtGui import QIcon
|
||||
from openai import AuthenticationError, OpenAI
|
||||
|
||||
from buzz.settings.settings import Settings
|
||||
|
|
@ -26,27 +23,9 @@ from buzz.store.keyring_store import get_password, Key
|
|||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
from buzz.locale import _
|
||||
from buzz.widgets.icon import INFO_ICON_PATH
|
||||
from buzz.settings.recording_transcriber_mode import RecordingTranscriberMode
|
||||
|
||||
BASE64_PATTERN = re.compile(r'^[A-Za-z0-9+/=_-]*$')
|
||||
|
||||
ui_locales = {
|
||||
"en_US": _("English"),
|
||||
"ca_ES": _("Catalan"),
|
||||
"da_DK": _("Danish"),
|
||||
"nl": _("Dutch"),
|
||||
"de_DE": _("German"),
|
||||
"es_ES": _("Spanish"),
|
||||
"it_IT": _("Italian"),
|
||||
"ja_JP": _("Japanese"),
|
||||
"lv_LV": _("Latvian"),
|
||||
"pl_PL": _("Polish"),
|
||||
"pt_BR": _("Portuguese (Brazil)"),
|
||||
"uk_UA": _("Ukrainian"),
|
||||
"zh_CN": _("Chinese (Simplified)"),
|
||||
"zh_TW": _("Chinese (Traditional)")
|
||||
}
|
||||
base64_pattern = re.compile(r'^[A-Za-z0-9+/=]*$')
|
||||
|
||||
|
||||
class GeneralPreferencesWidget(QWidget):
|
||||
|
|
@ -64,31 +43,6 @@ class GeneralPreferencesWidget(QWidget):
|
|||
|
||||
layout = QFormLayout(self)
|
||||
|
||||
self.ui_language_combo_box = QComboBox(self)
|
||||
self.ui_language_combo_box.addItems(ui_locales.values())
|
||||
system_locale = self.settings.value(Settings.Key.UI_LOCALE, QLocale().name())
|
||||
locale_index = 0
|
||||
for i, (code, language) in enumerate(ui_locales.items()):
|
||||
if code == system_locale:
|
||||
locale_index = i
|
||||
break
|
||||
self.ui_language_combo_box.setCurrentIndex(locale_index)
|
||||
self.ui_language_combo_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
|
||||
self.ui_language_combo_box.currentIndexChanged.connect(self.on_language_changed)
|
||||
|
||||
self.ui_locale_layout = QHBoxLayout()
|
||||
self.ui_locale_layout.setContentsMargins(0, 0, 0, 0)
|
||||
self.ui_locale_layout.setSpacing(0)
|
||||
self.ui_locale_layout.addWidget(self.ui_language_combo_box)
|
||||
|
||||
self.load_note_tooltip_icon = QLabel()
|
||||
self.load_note_tooltip_icon.setPixmap(QIcon(INFO_ICON_PATH).pixmap(23, 23))
|
||||
self.load_note_tooltip_icon.setToolTip(_("Restart required!"))
|
||||
self.load_note_tooltip_icon.setVisible(False)
|
||||
self.ui_locale_layout.addWidget(self.load_note_tooltip_icon)
|
||||
|
||||
layout.addRow(_("Ui Language"), self.ui_locale_layout)
|
||||
|
||||
self.font_size_spin_box = QSpinBox(self)
|
||||
self.font_size_spin_box.setMinimum(8)
|
||||
self.font_size_spin_box.setMaximum(32)
|
||||
|
|
@ -125,18 +79,6 @@ class GeneralPreferencesWidget(QWidget):
|
|||
self.custom_openai_base_url_line_edit.setPlaceholderText("https://api.openai.com/v1")
|
||||
layout.addRow(_("OpenAI base url"), self.custom_openai_base_url_line_edit)
|
||||
|
||||
self.openai_api_model = self.settings.value(
|
||||
key=Settings.Key.OPENAI_API_MODEL, default_value="whisper-1"
|
||||
)
|
||||
|
||||
self.openai_api_model_line_edit = LineEdit(self.openai_api_model, self)
|
||||
self.openai_api_model_line_edit.textChanged.connect(
|
||||
self.on_openai_api_model_changed
|
||||
)
|
||||
self.openai_api_model_line_edit.setMinimumWidth(200)
|
||||
self.openai_api_model_line_edit.setPlaceholderText("whisper-1")
|
||||
layout.addRow(_("OpenAI API model"), self.openai_api_model_line_edit)
|
||||
|
||||
default_export_file_name = self.settings.get_default_export_file_template()
|
||||
|
||||
default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
|
||||
|
|
@ -188,39 +130,6 @@ class GeneralPreferencesWidget(QWidget):
|
|||
|
||||
layout.addRow(_("Live recording mode"), self.recording_transcriber_mode)
|
||||
|
||||
export_note_label = QLabel(
|
||||
_("Note: Live recording export settings will be moved to the Advanced Settings in the Live Recording screen in a future version."),
|
||||
self,
|
||||
)
|
||||
export_note_label.setWordWrap(True)
|
||||
export_note_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
|
||||
layout.addRow("", export_note_label)
|
||||
|
||||
self.reduce_gpu_memory_enabled = self.settings.value(
|
||||
key=Settings.Key.REDUCE_GPU_MEMORY, default_value=False
|
||||
)
|
||||
|
||||
self.reduce_gpu_memory_checkbox = QCheckBox(_("Use 8-bit quantization to reduce memory usage"))
|
||||
self.reduce_gpu_memory_checkbox.setChecked(self.reduce_gpu_memory_enabled)
|
||||
self.reduce_gpu_memory_checkbox.setObjectName("ReduceGPUMemoryCheckbox")
|
||||
self.reduce_gpu_memory_checkbox.setToolTip(
|
||||
_("Applies to Huggingface and Faster Whisper models. "
|
||||
"Reduces GPU memory usage but may slightly decrease transcription quality.")
|
||||
)
|
||||
self.reduce_gpu_memory_checkbox.stateChanged.connect(self.on_reduce_gpu_memory_changed)
|
||||
layout.addRow(_("Reduce GPU RAM"), self.reduce_gpu_memory_checkbox)
|
||||
|
||||
self.force_cpu_enabled = self.settings.value(
|
||||
key=Settings.Key.FORCE_CPU, default_value=False
|
||||
)
|
||||
|
||||
self.force_cpu_checkbox = QCheckBox(_("Use only CPU and disable GPU acceleration"))
|
||||
self.force_cpu_checkbox.setChecked(self.force_cpu_enabled)
|
||||
self.force_cpu_checkbox.setObjectName("ForceCPUCheckbox")
|
||||
self.force_cpu_checkbox.setToolTip(_("Set this if larger models do not fit your GPU memory and Buzz crashes"))
|
||||
self.force_cpu_checkbox.stateChanged.connect(self.on_force_cpu_changed)
|
||||
layout.addRow(_("Disable GPU"), self.force_cpu_checkbox)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def on_default_export_file_name_changed(self, text: str):
|
||||
|
|
@ -232,7 +141,7 @@ class GeneralPreferencesWidget(QWidget):
|
|||
def on_click_test_openai_api_key_button(self):
|
||||
self.test_openai_api_key_button.setEnabled(False)
|
||||
|
||||
job = ValidateOpenAIApiKeyJob(api_key=self.openai_api_key)
|
||||
job = TestOpenAIApiKeyJob(api_key=self.openai_api_key)
|
||||
job.signals.success.connect(self.on_test_openai_api_key_success)
|
||||
job.signals.failed.connect(self.on_test_openai_api_key_failure)
|
||||
job.setAutoDelete(True)
|
||||
|
|
@ -258,19 +167,16 @@ class GeneralPreferencesWidget(QWidget):
|
|||
self.openai_api_key_changed.emit(key)
|
||||
|
||||
def on_openai_api_key_focus_out(self):
|
||||
if not BASE64_PATTERN.match(self.openai_api_key):
|
||||
if not base64_pattern.match(self.openai_api_key):
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
_("Invalid API key"),
|
||||
_("API supports only base64 characters (A-Za-z0-9+/=_-). Other characters in API key may cause errors."),
|
||||
_("API supports only base64 characters (A-Za-z0-9+/=). Other characters in API key may cause errors."),
|
||||
)
|
||||
|
||||
def on_custom_openai_base_url_changed(self, text: str):
|
||||
self.settings.set_value(Settings.Key.CUSTOM_OPENAI_BASE_URL, text)
|
||||
|
||||
def on_openai_api_model_changed(self, text: str):
|
||||
self.settings.set_value(Settings.Key.OPENAI_API_MODEL, text)
|
||||
|
||||
def on_recording_export_enable_changed(self, state: int):
|
||||
self.recording_export_enabled = state == 2
|
||||
|
||||
|
|
@ -293,14 +199,6 @@ class GeneralPreferencesWidget(QWidget):
|
|||
folder,
|
||||
)
|
||||
|
||||
def on_language_changed(self, index):
|
||||
selected_language = self.ui_language_combo_box.itemText(index)
|
||||
locale_code = next((code for code, lang in ui_locales.items() if lang == selected_language), "en_US")
|
||||
|
||||
self.load_note_tooltip_icon.setVisible(True)
|
||||
|
||||
self.settings.set_value(Settings.Key.UI_LOCALE, locale_code)
|
||||
|
||||
def on_font_size_changed(self, value):
|
||||
from buzz.widgets.application import Application
|
||||
font = self.font()
|
||||
|
|
@ -313,28 +211,7 @@ class GeneralPreferencesWidget(QWidget):
|
|||
def on_recording_transcriber_mode_changed(self, value):
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_MODE, value)
|
||||
|
||||
def on_force_cpu_changed(self, state: int):
|
||||
import os
|
||||
self.force_cpu_enabled = state == 2
|
||||
self.settings.set_value(Settings.Key.FORCE_CPU, self.force_cpu_enabled)
|
||||
|
||||
if self.force_cpu_enabled:
|
||||
os.environ["BUZZ_FORCE_CPU"] = "true"
|
||||
else:
|
||||
os.environ.pop("BUZZ_FORCE_CPU", None)
|
||||
|
||||
def on_reduce_gpu_memory_changed(self, state: int):
|
||||
import os
|
||||
self.reduce_gpu_memory_enabled = state == 2
|
||||
self.settings.set_value(Settings.Key.REDUCE_GPU_MEMORY, self.reduce_gpu_memory_enabled)
|
||||
|
||||
if self.reduce_gpu_memory_enabled:
|
||||
os.environ["BUZZ_REDUCE_GPU_MEMORY"] = "true"
|
||||
else:
|
||||
os.environ.pop("BUZZ_REDUCE_GPU_MEMORY", None)
|
||||
|
||||
|
||||
class ValidateOpenAIApiKeyJob(QRunnable):
|
||||
class TestOpenAIApiKeyJob(QRunnable):
|
||||
class Signals(QObject):
|
||||
success = pyqtSignal()
|
||||
failed = pyqtSignal(str)
|
||||
|
|
@ -376,7 +253,7 @@ class ValidateOpenAIApiKeyJob(QRunnable):
|
|||
client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=custom_openai_base_url if custom_openai_base_url else None,
|
||||
timeout=15,
|
||||
timeout=5,
|
||||
)
|
||||
client.models.list()
|
||||
self.signals.success.emit()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from buzz.model_loader import TranscriptionModel
|
|||
from buzz.transcriber.transcriber import (
|
||||
Task,
|
||||
OutputFormat,
|
||||
DEFAULT_WHISPER_TEMPERATURE,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionOptions,
|
||||
)
|
||||
|
|
@ -18,7 +19,7 @@ class FileTranscriptionPreferences:
|
|||
task: Task
|
||||
model: TranscriptionModel
|
||||
word_level_timings: bool
|
||||
extract_speech: bool
|
||||
temperature: Tuple[float, ...]
|
||||
initial_prompt: str
|
||||
enable_llm_translation: bool
|
||||
llm_prompt: str
|
||||
|
|
@ -30,7 +31,7 @@ class FileTranscriptionPreferences:
|
|||
settings.setValue("task", self.task)
|
||||
settings.setValue("model", self.model)
|
||||
settings.setValue("word_level_timings", self.word_level_timings)
|
||||
settings.setValue("extract_speech", self.extract_speech)
|
||||
settings.setValue("temperature", self.temperature)
|
||||
settings.setValue("initial_prompt", self.initial_prompt)
|
||||
settings.setValue("enable_llm_translation", self.enable_llm_translation)
|
||||
settings.setValue("llm_model", self.llm_model)
|
||||
|
|
@ -52,10 +53,7 @@ class FileTranscriptionPreferences:
|
|||
word_level_timings = False if word_level_timings_value == "false" \
|
||||
else bool(word_level_timings_value)
|
||||
|
||||
extract_speech_value = settings.value("extract_speech", False)
|
||||
extract_speech = False if extract_speech_value == "false" \
|
||||
else bool(extract_speech_value)
|
||||
|
||||
temperature = settings.value("temperature", DEFAULT_WHISPER_TEMPERATURE)
|
||||
initial_prompt = settings.value("initial_prompt", "")
|
||||
enable_llm_translation_value = settings.value("enable_llm_translation", False)
|
||||
enable_llm_translation = False if enable_llm_translation_value == "false" \
|
||||
|
|
@ -70,7 +68,7 @@ class FileTranscriptionPreferences:
|
|||
if model.model_type.is_available()
|
||||
else TranscriptionModel.default(),
|
||||
word_level_timings=word_level_timings,
|
||||
extract_speech=extract_speech,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
enable_llm_translation=enable_llm_translation,
|
||||
llm_model=llm_model,
|
||||
|
|
@ -89,12 +87,12 @@ class FileTranscriptionPreferences:
|
|||
return FileTranscriptionPreferences(
|
||||
task=transcription_options.task,
|
||||
language=transcription_options.language,
|
||||
temperature=transcription_options.temperature,
|
||||
initial_prompt=transcription_options.initial_prompt,
|
||||
enable_llm_translation=transcription_options.enable_llm_translation,
|
||||
llm_model=transcription_options.llm_model,
|
||||
llm_prompt=transcription_options.llm_prompt,
|
||||
word_level_timings=transcription_options.word_level_timings,
|
||||
extract_speech=transcription_options.extract_speech,
|
||||
model=transcription_options.model,
|
||||
output_formats=file_transcription_options.output_formats,
|
||||
)
|
||||
|
|
@ -109,12 +107,12 @@ class FileTranscriptionPreferences:
|
|||
TranscriptionOptions(
|
||||
task=self.task,
|
||||
language=self.language,
|
||||
temperature=self.temperature,
|
||||
initial_prompt=self.initial_prompt,
|
||||
enable_llm_translation=self.enable_llm_translation,
|
||||
llm_model=self.llm_model,
|
||||
llm_prompt=self.llm_prompt,
|
||||
word_level_timings=self.word_level_timings,
|
||||
extract_speech=self.extract_speech,
|
||||
model=self.model,
|
||||
openai_access_token=openai_access_token,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -13,13 +13,11 @@ class FolderWatchPreferences:
|
|||
input_directory: str
|
||||
output_directory: str
|
||||
file_transcription_options: FileTranscriptionPreferences
|
||||
delete_processed_files: bool = False
|
||||
|
||||
def save(self, settings: QSettings):
|
||||
settings.setValue("enabled", self.enabled)
|
||||
settings.setValue("input_folder", self.input_directory)
|
||||
settings.setValue("output_directory", self.output_directory)
|
||||
settings.setValue("delete_processed_files", self.delete_processed_files)
|
||||
settings.beginGroup("file_transcription_options")
|
||||
self.file_transcription_options.save(settings)
|
||||
settings.endGroup()
|
||||
|
|
@ -31,8 +29,6 @@ class FolderWatchPreferences:
|
|||
|
||||
input_folder = settings.value("input_folder", defaultValue="", type=str)
|
||||
output_folder = settings.value("output_directory", defaultValue="", type=str)
|
||||
delete_value = settings.value("delete_processed_files", False)
|
||||
delete_processed_files = False if delete_value == "false" else bool(delete_value)
|
||||
settings.beginGroup("file_transcription_options")
|
||||
file_transcription_options = FileTranscriptionPreferences.load(settings)
|
||||
settings.endGroup()
|
||||
|
|
@ -41,5 +37,4 @@ class FolderWatchPreferences:
|
|||
input_directory=input_folder,
|
||||
output_directory=output_folder,
|
||||
file_transcription_options=file_transcription_options,
|
||||
delete_processed_files=delete_processed_files,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import Qt, QThreadPool, QLocale
|
||||
from PyQt6.QtCore import Qt, QThreadPool
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QFormLayout,
|
||||
|
|
@ -40,7 +40,6 @@ class ModelsPreferencesWidget(QWidget):
|
|||
super().__init__(parent)
|
||||
|
||||
self.settings = Settings()
|
||||
self.ui_locale = self.settings.value(Settings.Key.UI_LOCALE, QLocale().name())
|
||||
self.model_downloader: Optional[ModelDownloader] = None
|
||||
|
||||
model_types = [
|
||||
|
|
@ -183,11 +182,6 @@ class ModelsPreferencesWidget(QWidget):
|
|||
model_size == WhisperModelSize.CUSTOM):
|
||||
continue
|
||||
|
||||
# Skip LUMII size for all non Latvians
|
||||
if (model_size == WhisperModelSize.LUMII and
|
||||
(self.model.model_type != ModelType.WHISPER_CPP or self.ui_locale != "lv_LV")):
|
||||
continue
|
||||
|
||||
model = TranscriptionModel(
|
||||
model_type=self.model.model_type,
|
||||
whisper_model_size=WhisperModelSize(model_size),
|
||||
|
|
@ -275,8 +269,7 @@ class ModelsPreferencesWidget(QWidget):
|
|||
QMessageBox.warning(self, _("Error"), f"{download_failed_label}: {error}")
|
||||
|
||||
def on_download_progress(self, progress: tuple):
|
||||
if progress[1] != 0:
|
||||
self.progress_dialog.set_value(float(progress[0]) / progress[1])
|
||||
self.progress_dialog.set_value(float(progress[0]) / progress[1])
|
||||
|
||||
def on_progress_dialog_canceled(self):
|
||||
self.model_downloader.cancel()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton
|
|||
from buzz.locale import _
|
||||
from buzz.settings.shortcut import Shortcut
|
||||
from buzz.settings.shortcuts import Shortcuts
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.widgets.sequence_edit import SequenceEdit
|
||||
|
||||
|
||||
|
|
@ -20,10 +19,8 @@ class ShortcutsEditorPreferencesWidget(QWidget):
|
|||
self.shortcuts = shortcuts
|
||||
|
||||
self.layout = QFormLayout(self)
|
||||
_field_height = LineEdit().sizeHint().height()
|
||||
for shortcut in Shortcut:
|
||||
sequence_edit = SequenceEdit(shortcuts.get(shortcut), self)
|
||||
sequence_edit.setFixedHeight(_field_height)
|
||||
sequence_edit.keySequenceChanged.connect(
|
||||
self.get_key_sequence_changed(shortcut)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,189 +0,0 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from PyQt6.QtCore import Qt
|
||||
from PyQt6.QtGui import QTextCursor
|
||||
from PyQt6.QtWidgets import QWidget, QVBoxLayout, QTextBrowser
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import Settings
|
||||
|
||||
import os
|
||||
|
||||
class PresentationWindow(QWidget):
|
||||
"""Window for displaying live transcripts in presentation mode"""
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.settings = Settings()
|
||||
self._current_transcript = ""
|
||||
self._current_translation = ""
|
||||
self.window_style = ""
|
||||
self.setWindowTitle(_("Live Transcript Presentation"))
|
||||
self.setWindowFlag(Qt.WindowType.Window)
|
||||
|
||||
# Window size
|
||||
self.resize(800, 600)
|
||||
|
||||
# Create layout
|
||||
layout = QVBoxLayout(self)
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
layout.setSpacing(0)
|
||||
|
||||
# Text display widget
|
||||
self.transcript_display = QTextBrowser(self)
|
||||
self.transcript_display.setReadOnly(True)
|
||||
|
||||
# Translation display (hidden first)
|
||||
self.translation_display = QTextBrowser(self)
|
||||
self.translation_display.setReadOnly(True)
|
||||
self.translation_display.hide()
|
||||
|
||||
# Add to layout
|
||||
layout.addWidget(self.transcript_display)
|
||||
layout.addWidget(self.translation_display)
|
||||
|
||||
self.load_settings()
|
||||
|
||||
def load_settings(self):
|
||||
"""Load and apply saved presentation settings"""
|
||||
theme = self.settings.value(
|
||||
Settings.Key.PRESENTATION_WINDOW_THEME,
|
||||
"light"
|
||||
)
|
||||
|
||||
# Load text size
|
||||
text_size = self.settings.value(
|
||||
Settings.Key.PRESENTATION_WINDOW_TEXT_SIZE,
|
||||
24,
|
||||
int
|
||||
)
|
||||
|
||||
# Load colors based on theme
|
||||
if theme == "light":
|
||||
text_color = "#000000"
|
||||
bg_color = "#FFFFFF"
|
||||
elif theme == "dark":
|
||||
text_color = "#FFFFFF"
|
||||
bg_color = "#000000"
|
||||
else:
|
||||
text_color = self.settings.value(
|
||||
Settings.Key.PRESENTATION_WINDOW_TEXT_COLOR,
|
||||
"#000000"
|
||||
)
|
||||
|
||||
bg_color = self.settings.value(
|
||||
Settings.Key.PRESENTATION_WINDOW_BACKGROUND_COLOR,
|
||||
"#FFFFFF"
|
||||
)
|
||||
|
||||
self.apply_styling(text_color, bg_color, text_size)
|
||||
|
||||
# Refresh content with new styling
|
||||
if self._current_transcript:
|
||||
self.update_transcripts(self._current_transcript)
|
||||
if self._current_translation:
|
||||
self.update_translations(self._current_translation)
|
||||
|
||||
def apply_styling(self, text_color: str, bg_color: str, text_size: int):
|
||||
"""Apply text color, background color and font size"""
|
||||
|
||||
# Load custom CSS if it exists
|
||||
css_file_path = self.get_css_file_path()
|
||||
|
||||
if os.path.exists(css_file_path):
|
||||
try:
|
||||
with open(css_file_path, "r", encoding="utf-8") as f:
|
||||
self.window_style = f.read()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to load custom CSS: {e}")
|
||||
else:
|
||||
self.window_style = f"""
|
||||
body {{
|
||||
color: {text_color};
|
||||
background-color: {bg_color};
|
||||
font-size: {text_size}pt;
|
||||
font-family: Arial, sans-serif;
|
||||
padding: 0;
|
||||
margin: 20px;
|
||||
}}
|
||||
"""
|
||||
|
||||
def update_transcripts(self, text: str):
|
||||
"""Update the transcript display with new text"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
self._current_transcript = text
|
||||
escaped_text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
html_text = escaped_text.replace("\n", "<br>")
|
||||
|
||||
html_content = f"""
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
{self.window_style}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{html_text}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
self.transcript_display.setHtml(html_content)
|
||||
self.transcript_display.moveCursor(QTextCursor.MoveOperation.End)
|
||||
|
||||
def update_translations(self, text: str):
|
||||
"""Update the translation display with new text"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
self._current_translation = text
|
||||
self.translation_display.show()
|
||||
|
||||
escaped_text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
html_text = escaped_text.replace("\n", "<br>")
|
||||
|
||||
html_content = f"""
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
{self.window_style}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{html_text}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
self.translation_display.setHtml(html_content)
|
||||
self.translation_display.moveCursor(QTextCursor.MoveOperation.End)
|
||||
|
||||
def toggle_fullscreen(self):
|
||||
"""Toggle fullscreen mode"""
|
||||
if self.isFullScreen():
|
||||
self.showNormal()
|
||||
else:
|
||||
self.showFullScreen()
|
||||
|
||||
def keyPressEvent(self, event):
|
||||
"""Handle keyboard events"""
|
||||
# ESC Key exits fullscreen
|
||||
if event.key() == Qt.Key.Key_Escape and self.isFullScreen():
|
||||
self.showNormal()
|
||||
event.accept()
|
||||
else:
|
||||
super().keyPressEvent(event)
|
||||
|
||||
|
||||
def get_css_file_path(self) -> str:
|
||||
"""Get path to custom CSS file"""
|
||||
cache_dir = user_cache_dir("Buzz")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
return os.path.join(cache_dir, "presentation_window_style.css")
|
||||
|
||||
|
||||
32
buzz/widgets/snap_notice.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from PyQt6.QtWidgets import QDialog, QVBoxLayout, QTextEdit, QLabel, QPushButton
|
||||
from buzz.locale import _
|
||||
|
||||
|
||||
class SnapNotice(QDialog):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.setWindowTitle(_("Snap permission notice"))
|
||||
|
||||
self.layout = QVBoxLayout(self)
|
||||
|
||||
self.notice_label = QLabel(_("Detected missing permissions, please check that snap permissions have been granted"))
|
||||
self.layout.addWidget(self.notice_label)
|
||||
|
||||
self.instruction_label = QLabel(_("To enable necessary permissions run the following commands in the terminal"))
|
||||
self.layout.addWidget(self.instruction_label)
|
||||
|
||||
self.text_edit = QTextEdit(self)
|
||||
self.text_edit.setPlainText(
|
||||
"sudo snap connect buzz:audio-record\n"
|
||||
"sudo snap connect buzz:password-manager-service\n"
|
||||
"sudo snap connect buzz:pulseaudio\n"
|
||||
"sudo snap connect buzz:removable-media"
|
||||
)
|
||||
self.text_edit.setReadOnly(True)
|
||||
self.text_edit.setFixedHeight(80)
|
||||
self.layout.addWidget(self.text_edit)
|
||||
|
||||
self.button = QPushButton(_("Close"), self)
|
||||
self.button.clicked.connect(self.close)
|
||||
self.layout.addWidget(self.button)
|
||||
|
|
@ -7,34 +7,23 @@ from PyQt6.QtWidgets import (
|
|||
QPlainTextEdit,
|
||||
QFormLayout,
|
||||
QLabel,
|
||||
QDoubleSpinBox,
|
||||
QLineEdit,
|
||||
QComboBox,
|
||||
QHBoxLayout,
|
||||
QPushButton,
|
||||
QSpinBox,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.transcriber.transcriber import TranscriptionOptions
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.settings.recording_transcriber_mode import RecordingTranscriberMode
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.widgets.transcriber.initial_prompt_text_edit import InitialPromptTextEdit
|
||||
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
|
||||
|
||||
|
||||
class AdvancedSettingsDialog(QDialog):
|
||||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
recording_mode_changed = pyqtSignal(RecordingTranscriberMode)
|
||||
hide_unconfirmed_changed = pyqtSignal(bool)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transcription_options: TranscriptionOptions,
|
||||
parent: QWidget | None = None,
|
||||
show_recording_settings: bool = False,
|
||||
self, transcription_options: TranscriptionOptions, parent: QWidget | None = None
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
|
|
@ -42,15 +31,29 @@ class AdvancedSettingsDialog(QDialog):
|
|||
self.settings = Settings()
|
||||
|
||||
self.setWindowTitle(_("Advanced Settings"))
|
||||
self.setMinimumWidth(800)
|
||||
|
||||
layout = QFormLayout(self)
|
||||
layout.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow)
|
||||
|
||||
transcription_settings_title= _("Speech recognition settings")
|
||||
transcription_settings_title_label = QLabel(f"<h4>{transcription_settings_title}</h4>", self)
|
||||
layout.addRow("", transcription_settings_title_label)
|
||||
|
||||
default_temperature_text = ", ".join(
|
||||
[str(temp) for temp in transcription_options.temperature]
|
||||
)
|
||||
self.temperature_line_edit = LineEdit(default_temperature_text, self)
|
||||
self.temperature_line_edit.setPlaceholderText(
|
||||
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')
|
||||
)
|
||||
self.temperature_line_edit.setMinimumWidth(170)
|
||||
self.temperature_line_edit.textChanged.connect(self.on_temperature_changed)
|
||||
self.temperature_line_edit.setValidator(TemperatureValidator(self))
|
||||
self.temperature_line_edit.setEnabled(
|
||||
transcription_options.model.model_type == ModelType.WHISPER
|
||||
)
|
||||
|
||||
layout.addRow(_("Temperature:"), self.temperature_line_edit)
|
||||
|
||||
self.initial_prompt_text_edit = InitialPromptTextEdit(
|
||||
transcription_options.initial_prompt,
|
||||
transcription_options.model.model_type,
|
||||
|
|
@ -71,160 +74,22 @@ class AdvancedSettingsDialog(QDialog):
|
|||
self.enable_llm_translation_checkbox.stateChanged.connect(self.on_enable_llm_translation_changed)
|
||||
layout.addRow("", self.enable_llm_translation_checkbox)
|
||||
|
||||
llm_model = self.transcription_options.llm_model or "gpt-4.1-mini"
|
||||
self.llm_model_line_edit = LineEdit(llm_model, self)
|
||||
self.llm_model_line_edit.textChanged.connect(self.on_llm_model_changed)
|
||||
self.llm_model_line_edit = LineEdit(self.transcription_options.llm_model, self)
|
||||
self.llm_model_line_edit.textChanged.connect(
|
||||
self.on_llm_model_changed
|
||||
)
|
||||
self.llm_model_line_edit.setMinimumWidth(170)
|
||||
self.llm_model_line_edit.setEnabled(self.transcription_options.enable_llm_translation)
|
||||
self.llm_model_label = QLabel(_("AI model:"))
|
||||
self.llm_model_label.setEnabled(self.transcription_options.enable_llm_translation)
|
||||
layout.addRow(self.llm_model_label, self.llm_model_line_edit)
|
||||
self.llm_model_line_edit.setPlaceholderText("gpt-3.5-turbo")
|
||||
layout.addRow(_("AI model:"), self.llm_model_line_edit)
|
||||
|
||||
default_llm_prompt = self.transcription_options.llm_prompt or _(
|
||||
"Please translate each text sent to you from English to Spanish. Translation will be used in an automated system, please do not add any comments or notes, just the translation."
|
||||
)
|
||||
self.llm_prompt_text_edit = QPlainTextEdit(default_llm_prompt)
|
||||
self.llm_prompt_text_edit = QPlainTextEdit(self.transcription_options.llm_prompt)
|
||||
self.llm_prompt_text_edit.setEnabled(self.transcription_options.enable_llm_translation)
|
||||
self.llm_prompt_text_edit.setPlaceholderText(_("Enter instructions for AI on how to translate..."))
|
||||
self.llm_prompt_text_edit.setMinimumWidth(170)
|
||||
self.llm_prompt_text_edit.setFixedHeight(80)
|
||||
self.llm_prompt_text_edit.setFixedHeight(115)
|
||||
self.llm_prompt_text_edit.textChanged.connect(self.on_llm_prompt_changed)
|
||||
self.llm_prompt_label = QLabel(_("Instructions for AI:"))
|
||||
self.llm_prompt_label.setEnabled(self.transcription_options.enable_llm_translation)
|
||||
layout.addRow(self.llm_prompt_label, self.llm_prompt_text_edit)
|
||||
|
||||
if show_recording_settings:
|
||||
recording_settings_title = _("Recording settings")
|
||||
recording_settings_title_label = QLabel(f"<h4>{recording_settings_title}</h4>", self)
|
||||
layout.addRow("", recording_settings_title_label)
|
||||
|
||||
self.silence_threshold_spin_box = QDoubleSpinBox(self)
|
||||
self.silence_threshold_spin_box.setRange(0.0, 1.0)
|
||||
self.silence_threshold_spin_box.setSingleStep(0.0005)
|
||||
self.silence_threshold_spin_box.setDecimals(4)
|
||||
self.silence_threshold_spin_box.setValue(transcription_options.silence_threshold)
|
||||
self.silence_threshold_spin_box.valueChanged.connect(self.on_silence_threshold_changed)
|
||||
self.silence_threshold_spin_box.setFixedWidth(90)
|
||||
layout.addRow(_("Silence threshold:"), self.silence_threshold_spin_box)
|
||||
|
||||
# Live recording mode
|
||||
self.recording_mode_combo = QComboBox(self)
|
||||
for mode in RecordingTranscriberMode:
|
||||
self.recording_mode_combo.addItem(mode.value)
|
||||
self.recording_mode_combo.setCurrentIndex(
|
||||
self.settings.value(Settings.Key.RECORDING_TRANSCRIBER_MODE, 0)
|
||||
)
|
||||
self.recording_mode_combo.currentIndexChanged.connect(self.on_recording_mode_changed)
|
||||
self.recording_mode_combo.setFixedWidth(250)
|
||||
layout.addRow(_("Live recording mode") + ":", self.recording_mode_combo)
|
||||
|
||||
self.line_separator_line_edit = QLineEdit(self)
|
||||
line_sep_display = repr(transcription_options.line_separator)[1:-1] or r"\n\n"
|
||||
self.line_separator_line_edit.setText(line_sep_display)
|
||||
self.line_separator_line_edit.textChanged.connect(self.on_line_separator_changed)
|
||||
self.line_separator_label = QLabel(_("Line separator:"))
|
||||
layout.addRow(self.line_separator_label, self.line_separator_line_edit)
|
||||
|
||||
self.transcription_step_spin_box = QDoubleSpinBox(self)
|
||||
self.transcription_step_spin_box.setRange(2.0, 5.0)
|
||||
self.transcription_step_spin_box.setSingleStep(0.1)
|
||||
self.transcription_step_spin_box.setDecimals(1)
|
||||
self.transcription_step_spin_box.setValue(transcription_options.transcription_step)
|
||||
self.transcription_step_spin_box.valueChanged.connect(self.on_transcription_step_changed)
|
||||
self.transcription_step_spin_box.setFixedWidth(80)
|
||||
self.transcription_step_label = QLabel(_("Transcription step:"))
|
||||
layout.addRow(self.transcription_step_label, self.transcription_step_spin_box)
|
||||
|
||||
hide_unconfirmed = self.settings.value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_HIDE_UNCONFIRMED, True
|
||||
)
|
||||
self.hide_unconfirmed_checkbox = QCheckBox(_("Hide unconfirmed"))
|
||||
self.hide_unconfirmed_checkbox.setChecked(hide_unconfirmed)
|
||||
self.hide_unconfirmed_checkbox.stateChanged.connect(self.on_hide_unconfirmed_changed)
|
||||
self.hide_unconfirmed_label = QLabel("")
|
||||
layout.addRow(self.hide_unconfirmed_label, self.hide_unconfirmed_checkbox)
|
||||
|
||||
self._update_recording_mode_visibility(
|
||||
RecordingTranscriberMode(self.recording_mode_combo.currentText())
|
||||
)
|
||||
|
||||
# Export enabled checkbox
|
||||
self._export_enabled = self.settings.value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_EXPORT_ENABLED, False
|
||||
)
|
||||
self.export_enabled_checkbox = QCheckBox(_("Enable live recording export"))
|
||||
self.export_enabled_checkbox.setChecked(self._export_enabled)
|
||||
self.export_enabled_checkbox.stateChanged.connect(self.on_export_enabled_changed)
|
||||
layout.addRow("", self.export_enabled_checkbox)
|
||||
|
||||
# Export folder
|
||||
export_folder = self.settings.value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_EXPORT_FOLDER, ""
|
||||
)
|
||||
self.export_folder_line_edit = LineEdit(export_folder, self)
|
||||
self.export_folder_line_edit.setEnabled(self._export_enabled)
|
||||
self.export_folder_line_edit.textChanged.connect(self.on_export_folder_changed)
|
||||
self.export_folder_browse_button = QPushButton(_("Browse"), self)
|
||||
self.export_folder_browse_button.setEnabled(self._export_enabled)
|
||||
self.export_folder_browse_button.clicked.connect(self.on_browse_export_folder)
|
||||
export_folder_row = QHBoxLayout()
|
||||
export_folder_row.addWidget(self.export_folder_line_edit)
|
||||
export_folder_row.addWidget(self.export_folder_browse_button)
|
||||
self.export_folder_label = QLabel(_("Export folder:"))
|
||||
self.export_folder_label.setEnabled(self._export_enabled)
|
||||
layout.addRow(self.export_folder_label, export_folder_row)
|
||||
|
||||
# Export file name template
|
||||
export_file_name = self.settings.value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_EXPORT_FILE_NAME, ""
|
||||
)
|
||||
self.export_file_name_line_edit = LineEdit(export_file_name, self)
|
||||
self.export_file_name_line_edit.setEnabled(self._export_enabled)
|
||||
self.export_file_name_line_edit.textChanged.connect(self.on_export_file_name_changed)
|
||||
self.export_file_name_label = QLabel(_("Export file name:"))
|
||||
self.export_file_name_label.setEnabled(self._export_enabled)
|
||||
layout.addRow(self.export_file_name_label, self.export_file_name_line_edit)
|
||||
|
||||
# Export file type
|
||||
self.export_file_type_combo = QComboBox(self)
|
||||
self.export_file_type_combo.addItem(_("Text file (.txt)"), "txt")
|
||||
self.export_file_type_combo.addItem(_("CSV (.csv)"), "csv")
|
||||
current_type = self.settings.value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_EXPORT_FILE_TYPE, "txt"
|
||||
)
|
||||
type_index = self.export_file_type_combo.findData(current_type)
|
||||
if type_index >= 0:
|
||||
self.export_file_type_combo.setCurrentIndex(type_index)
|
||||
self.export_file_type_combo.setEnabled(self._export_enabled)
|
||||
self.export_file_type_combo.currentIndexChanged.connect(self.on_export_file_type_changed)
|
||||
self.export_file_type_combo.setFixedWidth(200)
|
||||
self.export_file_type_label = QLabel(_("Export file type:"))
|
||||
self.export_file_type_label.setEnabled(self._export_enabled)
|
||||
layout.addRow(self.export_file_type_label, self.export_file_type_combo)
|
||||
|
||||
# Max entries
|
||||
max_entries = self.settings.value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_EXPORT_MAX_ENTRIES, 0, int
|
||||
)
|
||||
self.export_max_entries_spin = QSpinBox(self)
|
||||
self.export_max_entries_spin.setRange(0, 99)
|
||||
self.export_max_entries_spin.setValue(max_entries)
|
||||
self.export_max_entries_spin.setEnabled(self._export_enabled)
|
||||
self.export_max_entries_spin.valueChanged.connect(self.on_export_max_entries_changed)
|
||||
self.export_max_entries_spin.setFixedWidth(90)
|
||||
self.export_max_entries_label = QLabel(_("Limit export entries\n(0 = export all):"))
|
||||
self.export_max_entries_label.setEnabled(self._export_enabled)
|
||||
layout.addRow(self.export_max_entries_label, self.export_max_entries_spin)
|
||||
|
||||
_field_height = self.llm_model_line_edit.sizeHint().height()
|
||||
for widget in (
|
||||
self.line_separator_line_edit,
|
||||
self.silence_threshold_spin_box,
|
||||
self.recording_mode_combo,
|
||||
self.transcription_step_spin_box,
|
||||
self.export_file_type_combo,
|
||||
self.export_max_entries_spin,
|
||||
):
|
||||
widget.setFixedHeight(_field_height)
|
||||
layout.addRow(_("Instructions for AI:"), self.llm_prompt_text_edit)
|
||||
|
||||
button_box = QDialogButtonBox(
|
||||
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self
|
||||
|
|
@ -235,6 +100,15 @@ class AdvancedSettingsDialog(QDialog):
|
|||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
self.resize(self.sizeHint())
|
||||
|
||||
def on_temperature_changed(self, text: str):
|
||||
try:
|
||||
temperatures = [float(temp.strip()) for temp in text.split(",")]
|
||||
self.transcription_options.temperature = tuple(temperatures)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def on_initial_prompt_changed(self):
|
||||
self.transcription_options.initial_prompt = (
|
||||
|
|
@ -246,11 +120,8 @@ class AdvancedSettingsDialog(QDialog):
|
|||
self.transcription_options.enable_llm_translation = state == 2
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
enabled = self.transcription_options.enable_llm_translation
|
||||
self.llm_model_label.setEnabled(enabled)
|
||||
self.llm_model_line_edit.setEnabled(enabled)
|
||||
self.llm_prompt_label.setEnabled(enabled)
|
||||
self.llm_prompt_text_edit.setEnabled(enabled)
|
||||
self.llm_model_line_edit.setEnabled(self.transcription_options.enable_llm_translation)
|
||||
self.llm_prompt_text_edit.setEnabled(self.transcription_options.enable_llm_translation)
|
||||
|
||||
def on_llm_model_changed(self, text: str):
|
||||
self.transcription_options.llm_model = text
|
||||
|
|
@ -261,72 +132,3 @@ class AdvancedSettingsDialog(QDialog):
|
|||
self.llm_prompt_text_edit.toPlainText()
|
||||
)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_silence_threshold_changed(self, value: float):
|
||||
self.transcription_options.silence_threshold = value
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_line_separator_changed(self, text: str):
|
||||
try:
|
||||
self.transcription_options.line_separator = text.encode().decode("unicode_escape")
|
||||
except UnicodeDecodeError:
|
||||
return
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_recording_mode_changed(self, index: int):
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_MODE, index)
|
||||
mode = list(RecordingTranscriberMode)[index]
|
||||
self._update_recording_mode_visibility(mode)
|
||||
self.recording_mode_changed.emit(mode)
|
||||
|
||||
def _update_recording_mode_visibility(self, mode: RecordingTranscriberMode):
|
||||
is_append_and_correct = mode == RecordingTranscriberMode.APPEND_AND_CORRECT
|
||||
self.line_separator_label.setVisible(not is_append_and_correct)
|
||||
self.line_separator_line_edit.setVisible(not is_append_and_correct)
|
||||
self.transcription_step_label.setVisible(is_append_and_correct)
|
||||
self.transcription_step_spin_box.setVisible(is_append_and_correct)
|
||||
self.hide_unconfirmed_label.setVisible(is_append_and_correct)
|
||||
self.hide_unconfirmed_checkbox.setVisible(is_append_and_correct)
|
||||
|
||||
def on_transcription_step_changed(self, value: float):
|
||||
self.transcription_options.transcription_step = round(value, 1)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_hide_unconfirmed_changed(self, state: int):
|
||||
value = state == 2
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_HIDE_UNCONFIRMED, value)
|
||||
self.hide_unconfirmed_changed.emit(value)
|
||||
|
||||
def on_export_enabled_changed(self, state: int):
|
||||
self._export_enabled = state == 2
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_EXPORT_ENABLED, self._export_enabled)
|
||||
for widget in (
|
||||
self.export_folder_label,
|
||||
self.export_folder_line_edit,
|
||||
self.export_folder_browse_button,
|
||||
self.export_file_name_label,
|
||||
self.export_file_name_line_edit,
|
||||
self.export_file_type_label,
|
||||
self.export_file_type_combo,
|
||||
self.export_max_entries_label,
|
||||
self.export_max_entries_spin,
|
||||
):
|
||||
widget.setEnabled(self._export_enabled)
|
||||
|
||||
def on_export_folder_changed(self, text: str):
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_EXPORT_FOLDER, text)
|
||||
|
||||
def on_browse_export_folder(self):
|
||||
folder = QFileDialog.getExistingDirectory(self, _("Select Export Folder"))
|
||||
if folder:
|
||||
self.export_folder_line_edit.setText(folder)
|
||||
|
||||
def on_export_file_name_changed(self, text: str):
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_EXPORT_FILE_NAME, text)
|
||||
|
||||
def on_export_file_type_changed(self, index: int):
|
||||
file_type = self.export_file_type_combo.itemData(index)
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_EXPORT_FILE_TYPE, file_type)
|
||||
|
||||
def on_export_max_entries_changed(self, value: int):
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_EXPORT_MAX_ENTRIES, value)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from PyQt6 import QtGui
|
||||
|
|
@ -11,7 +10,7 @@ from PyQt6.QtWidgets import (
|
|||
|
||||
from buzz.dialogs import show_model_download_error_dialog
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelDownloader, WhisperModelSize, ModelType
|
||||
from buzz.model_loader import ModelDownloader
|
||||
from buzz.paths import file_path_as_title
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.store.keyring_store import get_password, Key
|
||||
|
|
@ -77,10 +76,6 @@ class FileTranscriberWidget(QWidget):
|
|||
self.openai_access_token_changed
|
||||
)
|
||||
|
||||
self.form_widget.transcription_options_changed.connect(
|
||||
self.reset_transcriber_controls
|
||||
)
|
||||
|
||||
self.run_button = QPushButton(_("Run"), self)
|
||||
self.run_button.setDefault(True)
|
||||
self.run_button.clicked.connect(self.on_click_run)
|
||||
|
|
@ -159,17 +154,7 @@ class FileTranscriberWidget(QWidget):
|
|||
self.reset_transcriber_controls()
|
||||
|
||||
def reset_transcriber_controls(self):
|
||||
button_enabled = True
|
||||
if (self.transcription_options.model.model_type == ModelType.FASTER_WHISPER
|
||||
and self.transcription_options.model.whisper_model_size == WhisperModelSize.CUSTOM
|
||||
and self.transcription_options.model.hugging_face_model_id == ""):
|
||||
button_enabled = False
|
||||
|
||||
if (self.transcription_options.model.model_type == ModelType.HUGGING_FACE
|
||||
and self.transcription_options.model.hugging_face_model_id == ""):
|
||||
button_enabled = False
|
||||
|
||||
self.run_button.setEnabled(button_enabled)
|
||||
self.run_button.setDisabled(False)
|
||||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
self.reset_transcriber_controls()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal, Qt
|
||||
|
|
@ -51,16 +50,6 @@ class FileTranscriptionFormWidget(QWidget):
|
|||
file_transcription_layout = QFormLayout()
|
||||
file_transcription_layout.addRow("", self.word_level_timings_checkbox)
|
||||
|
||||
self.extract_speech_checkbox = QCheckBox(_("Extract speech"))
|
||||
self.extract_speech_checkbox.setChecked(
|
||||
self.transcription_options.extract_speech
|
||||
)
|
||||
self.extract_speech_checkbox.stateChanged.connect(
|
||||
self.on_extract_speech_changed
|
||||
)
|
||||
|
||||
file_transcription_layout.addRow("", self.extract_speech_checkbox)
|
||||
|
||||
export_format_layout = QHBoxLayout()
|
||||
for output_format in OutputFormat:
|
||||
export_format_checkbox = QCheckBox(
|
||||
|
|
@ -80,10 +69,13 @@ class FileTranscriptionFormWidget(QWidget):
|
|||
layout.addLayout(file_transcription_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
self.reset_word_level_timings()
|
||||
|
||||
def on_transcription_options_changed(
|
||||
self, transcription_options: TranscriptionOptions
|
||||
):
|
||||
self.transcription_options = transcription_options
|
||||
self.reset_word_level_timings()
|
||||
self.transcription_options_changed.emit(
|
||||
(self.transcription_options, self.file_transcription_options)
|
||||
)
|
||||
|
|
@ -101,15 +93,6 @@ class FileTranscriptionFormWidget(QWidget):
|
|||
(self.transcription_options, self.file_transcription_options)
|
||||
)
|
||||
|
||||
def on_extract_speech_changed(self, value: int):
|
||||
self.transcription_options.extract_speech = (
|
||||
value == Qt.CheckState.Checked.value
|
||||
)
|
||||
|
||||
self.transcription_options_changed.emit(
|
||||
(self.transcription_options, self.file_transcription_options)
|
||||
)
|
||||
|
||||
def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
|
||||
def on_checkbox_state_changed(state: int):
|
||||
if state == Qt.CheckState.Checked.value:
|
||||
|
|
@ -122,3 +105,9 @@ class FileTranscriptionFormWidget(QWidget):
|
|||
)
|
||||
|
||||
return on_checkbox_state_changed
|
||||
|
||||
def reset_word_level_timings(self):
|
||||
self.word_level_timings_checkbox.setDisabled(
|
||||
self.transcription_options.model.model_type
|
||||
== ModelType.OPEN_AI_WHISPER_API
|
||||
)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
|
||||
self.setPlaceholderText(_("Huggingface ID of a model"))
|
||||
|
||||
self.setMinimumWidth(50)
|
||||
self.setMinimumWidth(255)
|
||||
|
||||
self.timer = QTimer(self)
|
||||
self.timer.setSingleShot(True)
|
||||
|
|
@ -64,8 +64,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
|
||||
def focusInEvent(self, event):
|
||||
super().focusInEvent(event)
|
||||
# Defer selectAll to run after mouse events are processed
|
||||
QTimer.singleShot(0, self.selectAll)
|
||||
self.clear()
|
||||
|
||||
def on_text_edited(self, text: str):
|
||||
self.model_selected.emit(text)
|
||||
|
|
|
|||
|
|
@ -10,4 +10,4 @@ class InitialPromptTextEdit(QPlainTextEdit):
|
|||
self.setPlaceholderText(_("Enter prompt..."))
|
||||
self.setEnabled(model_type.supports_initial_prompt)
|
||||
self.setMinimumWidth(350)
|
||||
self.setFixedHeight(80)
|
||||
self.setFixedHeight(115)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||
import os
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal, Qt
|
||||
from PyQt6.QtWidgets import QComboBox, QWidget, QFrame
|
||||
from PyQt6.QtWidgets import QComboBox, QWidget
|
||||
from PyQt6.QtGui import QStandardItem, QStandardItemModel
|
||||
|
||||
from buzz.locale import _
|
||||
|
|
@ -51,9 +51,3 @@ class LanguagesComboBox(QComboBox):
|
|||
|
||||
def on_index_changed(self, index: int):
|
||||
self.languageChanged.emit(self.languages[index][0])
|
||||
|
||||
def showPopup(self):
|
||||
super().showPopup()
|
||||
popup = self.findChild(QFrame)
|
||||
if popup and popup.height() > 400:
|
||||
popup.setFixedHeight(400)
|
||||
|
|
|
|||
|
|
@ -1,48 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QWidget, QSizePolicy
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
|
||||
|
||||
class MMSLanguageLineEdit(LineEdit):
|
||||
"""Text input for MMS language codes (ISO 639-3).
|
||||
|
||||
MMS models support 1000+ languages using ISO 639-3 codes (3 letters).
|
||||
Examples: eng (English), fra (French), deu (German), spa (Spanish)
|
||||
"""
|
||||
|
||||
languageChanged = pyqtSignal(str)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_language: str = "eng",
|
||||
parent: Optional[QWidget] = None
|
||||
):
|
||||
super().__init__(default_language, parent)
|
||||
self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
|
||||
self.setPlaceholderText(_("e.g., eng, fra, deu"))
|
||||
self.setToolTip(
|
||||
_("Enter an ISO 639-3 language code (3 letters).\n"
|
||||
"Examples: eng (English), fra (French), deu (German),\n"
|
||||
"spa (Spanish), lav (Latvian)")
|
||||
)
|
||||
self.setMaxLength(10) # Allow some flexibility for edge cases
|
||||
self.setMinimumWidth(100)
|
||||
|
||||
self.textChanged.connect(self._on_text_changed)
|
||||
|
||||
def _on_text_changed(self, text: str):
|
||||
"""Emit language changed signal with cleaned text."""
|
||||
cleaned = text.strip().lower()
|
||||
self.languageChanged.emit(cleaned)
|
||||
|
||||
def language(self) -> str:
|
||||
"""Get the current language code."""
|
||||
return self.text().strip().lower()
|
||||
|
||||
def setLanguage(self, language: str):
|
||||
"""Set the language code."""
|
||||
self.setText(language.strip().lower() if language else "eng")
|
||||
21
buzz/widgets/transcriber/temperature_validator.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from PyQt6.QtCore import QObject
|
||||
from PyQt6.QtGui import QValidator
|
||||
|
||||
|
||||
class TemperatureValidator(QValidator):
|
||||
def __init__(self, parent: Optional[QObject] = ...) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
def validate(
|
||||
self, text: str, cursor_position: int
|
||||
) -> Tuple["QValidator.State", str, int]:
|
||||
try:
|
||||
temp_strings = [temp.strip() for temp in text.split(",")]
|
||||
if temp_strings[-1] == "":
|
||||
return QValidator.State.Intermediate, text, cursor_position
|
||||
_ = [float(temp) for temp in temp_strings]
|
||||
return QValidator.State.Acceptable, text, cursor_position
|
||||
except ValueError:
|
||||
return QValidator.State.Invalid, text, cursor_position
|
||||
|
|
@ -3,14 +3,14 @@ import logging
|
|||
import platform
|
||||
from typing import Optional, List
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal, QLocale
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtGui import QIcon
|
||||
from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox, QLabel, QHBoxLayout
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.widgets.icon import INFO_ICON_PATH
|
||||
from buzz.model_loader import ModelType, WhisperModelSize, get_whisper_cpp_file_path, is_mms_model
|
||||
from buzz.model_loader import ModelType, WhisperModelSize, get_whisper_cpp_file_path
|
||||
from buzz.transcriber.transcriber import TranscriptionOptions, Task
|
||||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
|
|
@ -20,7 +20,6 @@ from buzz.widgets.transcriber.hugging_face_search_line_edit import (
|
|||
HuggingFaceSearchLineEdit,
|
||||
)
|
||||
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
|
||||
from buzz.widgets.transcriber.mms_language_line_edit import MMSLanguageLineEdit
|
||||
from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox
|
||||
|
||||
|
||||
|
|
@ -33,11 +32,9 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
|
||||
model_types: Optional[List[ModelType]] = None,
|
||||
parent: Optional[QWidget] = None,
|
||||
show_recording_settings: bool = False,
|
||||
):
|
||||
super().__init__(title="", parent=parent)
|
||||
self.settings = Settings()
|
||||
self.ui_locale = self.settings.value(Settings.Key.UI_LOCALE, QLocale().name())
|
||||
self.transcription_options = default_transcription_options
|
||||
|
||||
self.form_layout = QFormLayout(self)
|
||||
|
|
@ -50,9 +47,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
|
||||
|
||||
self.advanced_settings_dialog = AdvancedSettingsDialog(
|
||||
transcription_options=self.transcription_options,
|
||||
parent=self,
|
||||
show_recording_settings=show_recording_settings,
|
||||
transcription_options=self.transcription_options, parent=self
|
||||
)
|
||||
self.advanced_settings_dialog.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed
|
||||
|
|
@ -60,7 +55,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems(
|
||||
[size.value.title() for size in WhisperModelSize if size not in {WhisperModelSize.CUSTOM, WhisperModelSize.LUMII}]
|
||||
[size.value.title() for size in WhisperModelSize if size not in {WhisperModelSize.CUSTOM}]
|
||||
)
|
||||
self.whisper_model_size_combo_box.currentTextChanged.connect(
|
||||
self.on_whisper_model_size_changed
|
||||
|
|
@ -91,13 +86,6 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
)
|
||||
self.languages_combo_box.languageChanged.connect(self.on_language_changed)
|
||||
|
||||
# MMS language input (text field for ISO 639-3 codes)
|
||||
self.mms_language_line_edit = MMSLanguageLineEdit(
|
||||
default_language="eng", parent=self
|
||||
)
|
||||
self.mms_language_line_edit.languageChanged.connect(self.on_mms_language_changed)
|
||||
self.mms_language_line_edit.setVisible(False)
|
||||
|
||||
self.advanced_settings_button = AdvancedSettingsButton(self)
|
||||
self.advanced_settings_button.clicked.connect(self.open_advanced_settings)
|
||||
|
||||
|
|
@ -126,7 +114,6 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.form_layout.addRow(_("Api Key:"), self.openai_access_token_edit)
|
||||
self.form_layout.addRow(_("Task:"), self.tasks_combo_box)
|
||||
self.form_layout.addRow(_("Language:"), self.languages_combo_box)
|
||||
self.form_layout.addRow(_("Language:"), self.mms_language_line_edit)
|
||||
|
||||
self.reset_visible_rows()
|
||||
|
||||
|
|
@ -145,14 +132,6 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.transcription_options.language = language
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_mms_language_changed(self, language: str):
|
||||
"""Handle MMS language code changes."""
|
||||
if language == "":
|
||||
language = "eng" # Default to English for MMS
|
||||
|
||||
self.transcription_options.language = language
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
def on_task_changed(self, task: Task):
|
||||
self.transcription_options.task = task
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
|
@ -209,18 +188,6 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
WhisperModelSize.CUSTOM.value.title()
|
||||
)
|
||||
|
||||
# Leave LUMII model only for Latvian whisper_cpp
|
||||
lumii_model_index = (self.whisper_model_size_combo_box
|
||||
.findText(WhisperModelSize.LUMII.value.title()))
|
||||
|
||||
if lumii_model_index != -1 and (model_type != ModelType.WHISPER_CPP or self.ui_locale != "lv_LV"):
|
||||
self.whisper_model_size_combo_box.removeItem(lumii_model_index)
|
||||
|
||||
if lumii_model_index == -1 and model_type == ModelType.WHISPER_CPP and self.ui_locale == "lv_LV":
|
||||
self.whisper_model_size_combo_box.addItem(
|
||||
WhisperModelSize.LUMII.value.title()
|
||||
)
|
||||
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
self.transcription_options.model.whisper_model_size.value.title()
|
||||
)
|
||||
|
|
@ -249,18 +216,11 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.transcription_options.model.model_type == ModelType.WHISPER_CPP
|
||||
)
|
||||
|
||||
# Update language widget visibility (MMS vs Whisper)
|
||||
self._update_language_widget_visibility()
|
||||
|
||||
def on_model_type_changed(self, model_type: ModelType):
|
||||
self.transcription_options.model.model_type = model_type
|
||||
if not model_type.supports_initial_prompt:
|
||||
self.transcription_options.initial_prompt = ""
|
||||
|
||||
if (self.transcription_options.model.whisper_model_size == WhisperModelSize.LUMII
|
||||
and model_type != ModelType.WHISPER_CPP):
|
||||
self.transcription_options.model.whisper_model_size = WhisperModelSize.LARGEV3TURBO
|
||||
|
||||
self.reset_visible_rows()
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
|
|
@ -277,34 +237,3 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
||||
self.settings.save_custom_model_id(self.transcription_options.model)
|
||||
|
||||
# Update language widget visibility based on whether this is an MMS model
|
||||
self._update_language_widget_visibility()
|
||||
|
||||
def _update_language_widget_visibility(self):
|
||||
"""Update language widget visibility based on whether the selected model is MMS."""
|
||||
model_type = self.transcription_options.model.model_type
|
||||
model_id = self.transcription_options.model.hugging_face_model_id
|
||||
|
||||
# Check if this is an MMS model
|
||||
is_mms = (model_type == ModelType.HUGGING_FACE and is_mms_model(model_id))
|
||||
|
||||
# Show MMS language input for MMS models, show dropdown for others
|
||||
self.form_layout.setRowVisible(self.mms_language_line_edit, is_mms)
|
||||
self.form_layout.setRowVisible(self.languages_combo_box, not is_mms)
|
||||
|
||||
# Sync the language value when switching between MMS and non-MMS
|
||||
if is_mms:
|
||||
# When switching to MMS, use the MMS language input value
|
||||
mms_lang = self.mms_language_line_edit.language()
|
||||
if mms_lang:
|
||||
self.transcription_options.language = mms_lang
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
else:
|
||||
# When switching from MMS to a regular model, use the dropdown's current value
|
||||
# This prevents invalid MMS language codes (like "eng") being used with Whisper
|
||||
current_index = self.languages_combo_box.currentIndex()
|
||||
dropdown_lang = self.languages_combo_box.languages[current_index][0]
|
||||
if self.transcription_options.language != dropdown_lang:
|
||||
self.transcription_options.language = dropdown_lang if dropdown_lang else None
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
|
|
|||
|
|
@ -6,17 +6,10 @@ from PyQt6.QtCore import QFileSystemWatcher, pyqtSignal, QObject
|
|||
|
||||
from buzz.store.keyring_store import Key, get_password
|
||||
from buzz.transcriber.transcriber import FileTranscriptionTask
|
||||
from buzz.model_loader import ModelDownloader
|
||||
from buzz.widgets.preferences_dialog.models.folder_watch_preferences import (
|
||||
FolderWatchPreferences,
|
||||
)
|
||||
|
||||
# Supported media file extensions (audio and video)
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
".mp3", ".wav", ".m4a", ".ogg", ".opus", ".flac", # audio
|
||||
".mp4", ".webm", ".ogm", ".mov", ".mkv", ".avi", ".wmv", # video
|
||||
}
|
||||
|
||||
|
||||
class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
|
||||
preferences: FolderWatchPreferences
|
||||
|
|
@ -40,14 +33,9 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
|
|||
if len(self.directories()) > 0:
|
||||
self.removePaths(self.directories())
|
||||
if preferences.enabled:
|
||||
# Add the input directory and all subdirectories to the watcher
|
||||
for dirpath, dirnames, _ in os.walk(preferences.input_directory):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
self.addPath(dirpath)
|
||||
self.addPath(preferences.input_directory)
|
||||
logging.debug(
|
||||
'Watching for media files in "%s" and subdirectories',
|
||||
preferences.input_directory,
|
||||
'Watching for media files in "%s"', preferences.input_directory
|
||||
)
|
||||
|
||||
def find_tasks(self):
|
||||
|
|
@ -60,18 +48,8 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
|
|||
for dirpath, dirnames, filenames in os.walk(input_directory):
|
||||
for filename in filenames:
|
||||
file_path = os.path.join(dirpath, filename)
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
|
||||
# Check for temp conversion files (e.g., .ogg.wav)
|
||||
name_without_ext = os.path.splitext(filename)[0]
|
||||
secondary_ext = os.path.splitext(name_without_ext)[1].lower()
|
||||
is_temp_conversion_file = secondary_ext in SUPPORTED_EXTENSIONS
|
||||
|
||||
if (
|
||||
filename.startswith(".") # hidden files
|
||||
or file_ext not in SUPPORTED_EXTENSIONS # non-media files
|
||||
or is_temp_conversion_file # temp conversion files like .ogg.wav
|
||||
or "_speech.mp3" in filename # extracted speech output files
|
||||
or file_path in tasks # file already in tasks
|
||||
or file_path in self.paths_emitted # file already emitted
|
||||
):
|
||||
|
|
@ -86,39 +64,16 @@ class TranscriptionTaskFolderWatcher(QFileSystemWatcher):
|
|||
file_paths=[file_path],
|
||||
)
|
||||
model_path = transcription_options.model.get_local_model_path()
|
||||
|
||||
if model_path is None:
|
||||
ModelDownloader(model=transcription_options.model).run()
|
||||
model_path = transcription_options.model.get_local_model_path()
|
||||
|
||||
# Preserve subdirectory structure in output directory
|
||||
relative_path = os.path.relpath(dirpath, input_directory)
|
||||
if relative_path == ".":
|
||||
output_directory = self.preferences.output_directory
|
||||
else:
|
||||
output_directory = os.path.join(
|
||||
self.preferences.output_directory, relative_path
|
||||
)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
task = FileTranscriptionTask(
|
||||
file_path=file_path,
|
||||
original_file_path=file_path,
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
model_path=model_path,
|
||||
output_directory=output_directory,
|
||||
output_directory=self.preferences.output_directory,
|
||||
source=FileTranscriptionTask.Source.FOLDER_WATCH,
|
||||
delete_source_file=self.preferences.delete_processed_files,
|
||||
)
|
||||
self.task_found.emit(task)
|
||||
self.paths_emitted.add(file_path)
|
||||
|
||||
# Filter out hidden directories and add new subdirectories to the watcher
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
for dirname in dirnames:
|
||||
subdir_path = os.path.join(dirpath, dirname)
|
||||
if subdir_path not in self.directories():
|
||||
self.addPath(subdir_path)
|
||||
# Don't traverse into subdirectories
|
||||
break
|
||||
|
|
|
|||
|
|
@ -32,26 +32,21 @@ from buzz.widgets.transcription_record import TranscriptionRecord
|
|||
|
||||
class Column(enum.Enum):
|
||||
ID = 0
|
||||
ERROR_MESSAGE = 1
|
||||
EXPORT_FORMATS = 2
|
||||
FILE = 3
|
||||
OUTPUT_FOLDER = 4
|
||||
PROGRESS = 5
|
||||
LANGUAGE = 6
|
||||
MODEL_TYPE = 7
|
||||
SOURCE = 8
|
||||
STATUS = 9
|
||||
TASK = 10
|
||||
TIME_ENDED = 11
|
||||
TIME_QUEUED = 12
|
||||
TIME_STARTED = 13
|
||||
URL = 14
|
||||
WHISPER_MODEL_SIZE = 15
|
||||
HUGGING_FACE_MODEL_ID = 16
|
||||
WORD_LEVEL_TIMINGS = 17
|
||||
EXTRACT_SPEECH = 18
|
||||
NAME = 19
|
||||
NOTES = 20
|
||||
ERROR_MESSAGE = auto()
|
||||
EXPORT_FORMATS = auto()
|
||||
FILE = auto()
|
||||
OUTPUT_FOLDER = auto()
|
||||
PROGRESS = auto()
|
||||
LANGUAGE = auto()
|
||||
MODEL_TYPE = auto()
|
||||
SOURCE = auto()
|
||||
STATUS = auto()
|
||||
TASK = auto()
|
||||
TIME_ENDED = auto()
|
||||
TIME_QUEUED = auto()
|
||||
TIME_STARTED = auto()
|
||||
URL = auto()
|
||||
WHISPER_MODEL_SIZE = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -84,7 +79,7 @@ def format_record_status_text(record: QSqlRecord) -> str:
|
|||
return _("Canceled")
|
||||
case FileTranscriptionTask.Status.QUEUED:
|
||||
return _("Queued")
|
||||
case _: # Case to handle UNKNOWN status
|
||||
case _:
|
||||
return ""
|
||||
|
||||
column_definitions = [
|
||||
|
|
@ -94,10 +89,9 @@ column_definitions = [
|
|||
column=Column.FILE,
|
||||
width=400,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: record.value("name") or (
|
||||
os.path.basename(record.value("file")) if record.value("file")
|
||||
else record.value("url") or ""
|
||||
)
|
||||
text_getter=lambda record: record.value("url")
|
||||
if record.value("url") != ""
|
||||
else os.path.basename(record.value("file"))
|
||||
),
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
|
|
@ -113,7 +107,7 @@ column_definitions = [
|
|||
ColDef(
|
||||
id="task",
|
||||
header=_("Task"),
|
||||
column=Column.TASK,
|
||||
column=Column.SOURCE,
|
||||
width=120,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: TASK_LABEL_TRANSLATIONS[Task(record.value("task"))]
|
||||
|
|
@ -125,9 +119,19 @@ column_definitions = [
|
|||
column=Column.STATUS,
|
||||
width=180,
|
||||
delegate=RecordDelegate(text_getter=format_record_status_text),
|
||||
hidden_toggleable=True,
|
||||
hidden_toggleable=False,
|
||||
),
|
||||
ColDef(
|
||||
id="date_added",
|
||||
header=_("Date Added"),
|
||||
column=Column.TIME_QUEUED,
|
||||
width=180,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: datetime.fromisoformat(
|
||||
record.value("time_queued")
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
),
|
||||
),
|
||||
|
||||
ColDef(
|
||||
id="date_completed",
|
||||
header=_("Date Completed"),
|
||||
|
|
@ -140,26 +144,6 @@ column_definitions = [
|
|||
if record.value("time_ended") != ""
|
||||
else ""
|
||||
),
|
||||
), ColDef(
|
||||
id="date_added",
|
||||
header=_("Date Added"),
|
||||
column=Column.TIME_QUEUED,
|
||||
width=180,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: datetime.fromisoformat(
|
||||
record.value("time_queued")
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
),
|
||||
),
|
||||
ColDef(
|
||||
id="notes",
|
||||
header=_("Notes"),
|
||||
column=Column.NOTES,
|
||||
width=300,
|
||||
delegate=RecordDelegate(
|
||||
text_getter=lambda record: record.value("notes") or ""
|
||||
),
|
||||
hidden_toggleable=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -169,72 +153,28 @@ class TranscriptionTasksTableHeaderView(QHeaderView):
|
|||
|
||||
def contextMenuEvent(self, event):
|
||||
menu = QMenu(self)
|
||||
|
||||
# Add reset column order option
|
||||
menu.addAction(_("Reset Column Order")).triggered.connect(self.parent().reset_column_order)
|
||||
menu.addSeparator()
|
||||
|
||||
# Add column visibility toggles
|
||||
for definition in column_definitions:
|
||||
if definition.hidden_toggleable:
|
||||
action = menu.addAction(definition.header)
|
||||
action.setCheckable(True)
|
||||
action.setChecked(not self.parent().isColumnHidden(definition.column.value))
|
||||
action.toggled.connect(
|
||||
lambda checked, column_index=definition.column.value: self.on_column_checked(
|
||||
column_index, checked
|
||||
)
|
||||
if not definition.hidden_toggleable:
|
||||
continue
|
||||
action = menu.addAction(definition.header)
|
||||
action.setCheckable(True)
|
||||
action.setChecked(not self.isSectionHidden(definition.column.value))
|
||||
action.toggled.connect(
|
||||
lambda checked, column_index=definition.column.value: self.on_column_checked(
|
||||
column_index, checked
|
||||
)
|
||||
)
|
||||
menu.exec(event.globalPos())
|
||||
|
||||
def on_column_checked(self, column_index: int, checked: bool):
|
||||
# Find the column definition for this index
|
||||
column_def = None
|
||||
for definition in column_definitions:
|
||||
if definition.column.value == column_index:
|
||||
column_def = definition
|
||||
break
|
||||
|
||||
# If we're hiding the column, save its current width first
|
||||
if not checked and not self.parent().isColumnHidden(column_index):
|
||||
current_width = self.parent().columnWidth(column_index)
|
||||
if current_width > 0: # Only save if there's a meaningful width
|
||||
self.parent().settings.begin_group(self.parent().settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_WIDTHS)
|
||||
self.parent().settings.settings.setValue(column_def.id, current_width)
|
||||
self.parent().settings.end_group()
|
||||
|
||||
# Update the visibility state on the table view (not header view)
|
||||
self.parent().setColumnHidden(column_index, not checked)
|
||||
|
||||
# Save current column order before any reloading
|
||||
self.parent().save_column_order()
|
||||
|
||||
# Save both visibility and widths after the change
|
||||
self.setSectionHidden(column_index, not checked)
|
||||
self.parent().save_column_visibility()
|
||||
self.parent().save_column_widths()
|
||||
|
||||
# Ensure settings are synchronized
|
||||
self.parent().settings.settings.sync()
|
||||
|
||||
# Force a complete refresh of the table
|
||||
self.parent().viewport().update()
|
||||
self.parent().repaint()
|
||||
self.parent().horizontalHeader().update()
|
||||
self.parent().updateGeometry()
|
||||
self.parent().adjustSize()
|
||||
|
||||
# Force a model refresh to ensure the view is updated
|
||||
self.parent().model().layoutChanged.emit()
|
||||
|
||||
self.parent().reload_column_order_from_settings()
|
||||
|
||||
class TranscriptionTasksTableWidget(QTableView):
|
||||
return_clicked = pyqtSignal()
|
||||
delete_requested = pyqtSignal()
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
self.transcription_service = None
|
||||
|
||||
self.setHorizontalHeader(TranscriptionTasksTableHeaderView(Qt.Orientation.Horizontal, self))
|
||||
|
||||
|
|
@ -250,69 +190,58 @@ class TranscriptionTasksTableWidget(QTableView):
|
|||
|
||||
self.settings = Settings()
|
||||
|
||||
# Set up column headers and delegates
|
||||
self.settings.begin_group(
|
||||
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
|
||||
)
|
||||
for definition in column_definitions:
|
||||
self.model().setHeaderData(
|
||||
definition.column.value,
|
||||
Qt.Orientation.Horizontal,
|
||||
definition.header,
|
||||
)
|
||||
|
||||
visible = True
|
||||
if definition.hidden_toggleable:
|
||||
visible = self.settings.settings.value(definition.id, "true") in {"true", "True", True}
|
||||
|
||||
self.setColumnHidden(definition.column.value, not visible)
|
||||
if definition.width is not None:
|
||||
self.setColumnWidth(definition.column.value, definition.width)
|
||||
if definition.delegate is not None:
|
||||
self.setItemDelegateForColumn(
|
||||
definition.column.value, definition.delegate
|
||||
)
|
||||
|
||||
# Load column visibility
|
||||
self.load_column_visibility()
|
||||
self.settings.end_group()
|
||||
|
||||
self.model().select()
|
||||
self.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
|
||||
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
self.verticalHeader().hide()
|
||||
self.setAlternatingRowColors(True)
|
||||
|
||||
# Enable column sorting and moving
|
||||
self.setSortingEnabled(True)
|
||||
self.horizontalHeader().setSectionsMovable(True)
|
||||
self.horizontalHeader().setSectionsClickable(True)
|
||||
self.horizontalHeader().setSortIndicatorShown(True)
|
||||
|
||||
# Connect signals for column resize and move
|
||||
self.horizontalHeader().sectionResized.connect(self.on_column_resized)
|
||||
self.horizontalHeader().sectionMoved.connect(self.on_column_moved)
|
||||
self.horizontalHeader().sortIndicatorChanged.connect(self.on_sort_indicator_changed)
|
||||
|
||||
# Load saved column order, widths, and sort state
|
||||
self.load_column_order()
|
||||
self.load_column_widths()
|
||||
self.load_sort_state()
|
||||
|
||||
|
||||
# Reload column visibility after all reordering is complete
|
||||
self.load_column_visibility()
|
||||
# Show date added before date completed
|
||||
self.horizontalHeader().swapSections(11, 12)
|
||||
|
||||
def contextMenuEvent(self, event):
|
||||
menu = QMenu(self)
|
||||
|
||||
# Add transcription actions if a row is selected
|
||||
selected_rows = self.selectionModel().selectedRows()
|
||||
if selected_rows:
|
||||
transcription = self.transcription(selected_rows[0])
|
||||
|
||||
# Add restart/continue action for failed/canceled tasks
|
||||
if transcription.status in ["failed", "canceled"]:
|
||||
restart_action = menu.addAction(_("Restart Transcription"))
|
||||
restart_action.triggered.connect(self.on_restart_transcription_action)
|
||||
menu.addSeparator()
|
||||
|
||||
rename_action = menu.addAction(_("Rename"))
|
||||
rename_action.triggered.connect(self.on_rename_action)
|
||||
|
||||
notes_action = menu.addAction(_("Add/Edit Notes"))
|
||||
notes_action.triggered.connect(self.on_notes_action)
|
||||
|
||||
for definition in column_definitions:
|
||||
if not definition.hidden_toggleable:
|
||||
continue
|
||||
action = menu.addAction(definition.header)
|
||||
action.setCheckable(True)
|
||||
action.setChecked(not self.isColumnHidden(definition.column.value))
|
||||
action.toggled.connect(
|
||||
lambda checked,
|
||||
column_index=definition.column.value: self.on_column_checked(
|
||||
column_index, checked
|
||||
)
|
||||
)
|
||||
menu.exec(event.globalPos())
|
||||
|
||||
def on_column_checked(self, column_index: int, checked: bool):
|
||||
self.setColumnHidden(column_index, not checked)
|
||||
self.save_column_visibility()
|
||||
|
||||
def save_column_visibility(self):
|
||||
self.settings.begin_group(
|
||||
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
|
||||
|
|
@ -323,257 +252,22 @@ class TranscriptionTasksTableWidget(QTableView):
|
|||
)
|
||||
self.settings.end_group()
|
||||
|
||||
def on_column_resized(self, logical_index: int, old_size: int, new_size: int):
|
||||
"""Handle column resize events"""
|
||||
self.save_column_widths()
|
||||
|
||||
def on_column_moved(self, logical_index: int, old_visual_index: int, new_visual_index: int):
|
||||
"""Handle column move events"""
|
||||
self.save_column_order()
|
||||
# Refresh visibility after column move to ensure it's maintained
|
||||
self.load_column_visibility()
|
||||
|
||||
def on_sort_indicator_changed(self, logical_index: int, order: Qt.SortOrder):
|
||||
"""Handle sort indicator change events"""
|
||||
self.save_sort_state()
|
||||
|
||||
def on_double_click(self, index: QModelIndex):
|
||||
"""Handle double-click events - trigger notes edit for notes column"""
|
||||
if index.column() == Column.NOTES.value:
|
||||
self.on_notes_action()
|
||||
|
||||
def save_column_widths(self):
|
||||
"""Save current column widths to settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_WIDTHS)
|
||||
for definition in column_definitions:
|
||||
# Only save width if column is visible and has a meaningful width
|
||||
if not self.isColumnHidden(definition.column.value):
|
||||
width = self.columnWidth(definition.column.value)
|
||||
if width > 0: # Only save if there's a meaningful width
|
||||
self.settings.settings.setValue(definition.id, width)
|
||||
self.settings.end_group()
|
||||
|
||||
def save_column_order(self):
|
||||
"""Save current column order to settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_ORDER)
|
||||
header = self.horizontalHeader()
|
||||
for visual_index in range(header.count()):
|
||||
logical_index = header.logicalIndex(visual_index)
|
||||
# Find the column definition for this logical index
|
||||
for definition in column_definitions:
|
||||
if definition.column.value == logical_index:
|
||||
self.settings.settings.setValue(definition.id, visual_index)
|
||||
break
|
||||
self.settings.end_group()
|
||||
|
||||
def load_column_widths(self):
|
||||
"""Load saved column widths from settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_WIDTHS)
|
||||
for definition in column_definitions:
|
||||
if definition.width is not None: # Only load if column has a default width
|
||||
saved_width = self.settings.settings.value(definition.id, definition.width)
|
||||
if saved_width is not None:
|
||||
self.setColumnWidth(definition.column.value, int(saved_width))
|
||||
self.settings.end_group()
|
||||
|
||||
def save_sort_state(self):
|
||||
"""Save current sort state to settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_SORT_STATE)
|
||||
header = self.horizontalHeader()
|
||||
self.settings.settings.setValue("column", header.sortIndicatorSection())
|
||||
self.settings.settings.setValue("order", header.sortIndicatorOrder().value)
|
||||
self.settings.end_group()
|
||||
|
||||
def load_sort_state(self):
|
||||
"""Load saved sort state from settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_SORT_STATE)
|
||||
column = self.settings.settings.value("column")
|
||||
order = self.settings.settings.value("order")
|
||||
self.settings.end_group()
|
||||
|
||||
if column is not None and order is not None:
|
||||
sort_order = Qt.SortOrder(int(order))
|
||||
self.sortByColumn(int(column), sort_order)
|
||||
|
||||
def load_column_visibility(self):
|
||||
"""Load saved column visibility from settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY)
|
||||
for definition in column_definitions:
|
||||
visible = True
|
||||
if definition.hidden_toggleable:
|
||||
value = self.settings.settings.value(definition.id, "true")
|
||||
visible = value in {"true", "True", True}
|
||||
|
||||
self.setColumnHidden(definition.column.value, not visible)
|
||||
self.settings.end_group()
|
||||
|
||||
# Force a refresh of the table layout
|
||||
self.horizontalHeader().update()
|
||||
self.viewport().update()
|
||||
self.updateGeometry()
|
||||
|
||||
def load_column_order(self):
|
||||
"""Load saved column order from settings"""
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_ORDER)
|
||||
|
||||
# Create a mapping of column IDs to their saved visual positions
|
||||
column_positions = {}
|
||||
for definition in column_definitions:
|
||||
saved_position = self.settings.settings.value(definition.id)
|
||||
if saved_position is not None:
|
||||
column_positions[definition.column.value] = int(saved_position)
|
||||
|
||||
self.settings.end_group()
|
||||
|
||||
# Apply the saved order
|
||||
if column_positions:
|
||||
header = self.horizontalHeader()
|
||||
for logical_index, visual_position in column_positions.items():
|
||||
if 0 <= visual_position < header.count():
|
||||
header.moveSection(header.visualIndex(logical_index), visual_position)
|
||||
|
||||
def reset_column_order(self):
|
||||
"""Reset column order to default"""
|
||||
|
||||
# Reset column widths to defaults
|
||||
for definition in column_definitions:
|
||||
if definition.width is not None:
|
||||
self.setColumnWidth(definition.column.value, definition.width)
|
||||
|
||||
# Show all columns
|
||||
for definition in column_definitions:
|
||||
self.setColumnHidden(definition.column.value, False)
|
||||
|
||||
# Restore default column order
|
||||
header = self.horizontalHeader()
|
||||
# Move each section to its default position in order
|
||||
# To avoid index shifting, move from left to right
|
||||
for target_visual_index, definition in enumerate(column_definitions):
|
||||
logical_index = definition.column.value
|
||||
current_visual_index = header.visualIndex(logical_index)
|
||||
if current_visual_index != target_visual_index:
|
||||
header.moveSection(current_visual_index, target_visual_index)
|
||||
|
||||
# Reset sort to default (TIME_QUEUED descending)
|
||||
self.sortByColumn(Column.TIME_QUEUED.value, Qt.SortOrder.DescendingOrder)
|
||||
|
||||
# Clear saved settings
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_ORDER)
|
||||
self.settings.settings.remove("")
|
||||
self.settings.end_group()
|
||||
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_WIDTHS)
|
||||
self.settings.settings.remove("")
|
||||
self.settings.end_group()
|
||||
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_SORT_STATE)
|
||||
self.settings.settings.remove("")
|
||||
self.settings.end_group()
|
||||
|
||||
# Save the reset state for visibility, widths, and sort
|
||||
self.save_column_visibility()
|
||||
self.save_column_widths()
|
||||
self.save_sort_state()
|
||||
|
||||
# Force a refresh of the table layout
|
||||
self.horizontalHeader().update()
|
||||
self.viewport().update()
|
||||
self.updateGeometry()
|
||||
|
||||
def reload_column_order_from_settings(self):
|
||||
"""Reload column order, width, and visibility from settings"""
|
||||
|
||||
# --- Load column visibility ---
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY)
|
||||
visibility_settings = {}
|
||||
for definition in column_definitions:
|
||||
vis = self.settings.settings.value(definition.id)
|
||||
if vis is not None:
|
||||
visibility_settings[definition.id] = str(vis).lower() not in ("0", "false", "no")
|
||||
self.settings.end_group()
|
||||
|
||||
# --- Load column widths ---
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_WIDTHS)
|
||||
width_settings = {}
|
||||
for definition in column_definitions:
|
||||
width = self.settings.settings.value(definition.id)
|
||||
if width is not None:
|
||||
try:
|
||||
width_settings[definition.id] = int(width)
|
||||
except Exception:
|
||||
pass
|
||||
self.settings.end_group()
|
||||
|
||||
# --- Load column order ---
|
||||
self.settings.begin_group(Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_ORDER)
|
||||
order_settings = {}
|
||||
for definition in column_definitions:
|
||||
pos = self.settings.settings.value(definition.id)
|
||||
if pos is not None:
|
||||
try:
|
||||
order_settings[definition.column.value] = int(pos)
|
||||
except Exception:
|
||||
pass
|
||||
self.settings.end_group()
|
||||
|
||||
# --- Apply visibility, widths, and order ---
|
||||
header = self.horizontalHeader()
|
||||
|
||||
# First, set visibility and width for each column
|
||||
for definition in column_definitions:
|
||||
is_visible = visibility_settings.get(definition.id, True)
|
||||
width = width_settings.get(definition.id, definition.width)
|
||||
self.setColumnHidden(definition.column.value, not is_visible)
|
||||
if width is not None:
|
||||
self.setColumnWidth(definition.column.value, max(width, 100))
|
||||
|
||||
# Then, apply column order
|
||||
# Build a list of (logical_index, visual_position) for ALL columns (including hidden ones)
|
||||
all_columns = [
|
||||
(definition.column.value, order_settings.get(definition.column.value, idx))
|
||||
for idx, definition in enumerate(column_definitions)
|
||||
]
|
||||
# Sort by saved visual position
|
||||
all_columns.sort(key=lambda x: x[1])
|
||||
|
||||
# Move sections to match the saved order
|
||||
for target_visual, (logical_index, _) in enumerate(all_columns):
|
||||
current_visual = header.visualIndex(logical_index)
|
||||
if current_visual != target_visual:
|
||||
header.moveSection(current_visual, target_visual)
|
||||
|
||||
def copy_selected_fields(self):
|
||||
selected_text = ""
|
||||
for row in self.selectionModel().selectedRows():
|
||||
row_index = row.row()
|
||||
file_name = self.model().data(self.model().index(row_index, Column.FILE.value))
|
||||
url = self.model().data(self.model().index(row_index, Column.URL.value))
|
||||
file_name = self.model().data(self.model().index(row_index, 3))
|
||||
url = self.model().data(self.model().index(row_index, 14))
|
||||
|
||||
selected_text += f"{file_name}{url}\n"
|
||||
|
||||
selected_text = selected_text.rstrip("\n")
|
||||
QApplication.clipboard().setText(selected_text)
|
||||
|
||||
def mouseDoubleClickEvent(self, event: QtGui.QMouseEvent) -> None:
|
||||
"""Override double-click to prevent default behavior when clicking on notes column"""
|
||||
index = self.indexAt(event.pos())
|
||||
if index.isValid() and index.column() == Column.NOTES.value:
|
||||
# Handle our custom double-click action without triggering default behavior
|
||||
self.on_double_click(index)
|
||||
event.accept()
|
||||
else:
|
||||
# For other columns, use default behavior
|
||||
super().mouseDoubleClickEvent(event)
|
||||
|
||||
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
|
||||
if event.key() == Qt.Key.Key_Return:
|
||||
self.return_clicked.emit()
|
||||
|
||||
if event.key() == Qt.Key.Key_Delete:
|
||||
if self.selectionModel().selectedRows():
|
||||
self.delete_requested.emit()
|
||||
return
|
||||
|
||||
if event.matches(QKeySequence.StandardKey.Copy):
|
||||
self.copy_selected_fields()
|
||||
return
|
||||
|
|
@ -613,195 +307,3 @@ class TranscriptionTasksTableWidget(QTableView):
|
|||
if hh == 0:
|
||||
return result
|
||||
return f"{hh}h {result}"
|
||||
|
||||
def on_rename_action(self):
|
||||
selected_rows = self.selectionModel().selectedRows()
|
||||
if not selected_rows:
|
||||
return
|
||||
|
||||
# Get the first selected transcription
|
||||
transcription = self.transcription(selected_rows[0])
|
||||
|
||||
# Get current name or fallback to file name
|
||||
current_name = transcription.name or (
|
||||
transcription.url if transcription.url
|
||||
else os.path.basename(transcription.file) if transcription.file
|
||||
else ""
|
||||
)
|
||||
|
||||
# Show input dialog
|
||||
from PyQt6.QtWidgets import QInputDialog
|
||||
new_name, ok = QInputDialog.getText(
|
||||
self,
|
||||
_("Rename Transcription"),
|
||||
_("Enter new name:"),
|
||||
text=current_name
|
||||
)
|
||||
|
||||
if ok and new_name.strip():
|
||||
# Update the transcription name
|
||||
from uuid import UUID
|
||||
self.transcription_service.update_transcription_name(
|
||||
UUID(transcription.id),
|
||||
new_name.strip()
|
||||
)
|
||||
self.refresh_all()
|
||||
|
||||
def on_notes_action(self):
|
||||
selected_rows = self.selectionModel().selectedRows()
|
||||
if not selected_rows:
|
||||
return
|
||||
|
||||
# Get the first selected transcription
|
||||
transcription = self.transcription(selected_rows[0])
|
||||
|
||||
# Show input dialog for notes
|
||||
from PyQt6.QtWidgets import QInputDialog
|
||||
current_notes = transcription.notes or ""
|
||||
new_notes, ok = QInputDialog.getMultiLineText(
|
||||
self,
|
||||
_("Notes"),
|
||||
_("Enter some relevant notes for this transcription:"),
|
||||
text=current_notes
|
||||
)
|
||||
|
||||
if ok:
|
||||
# Update the transcription notes
|
||||
from uuid import UUID
|
||||
self.transcription_service.update_transcription_notes(
|
||||
UUID(transcription.id),
|
||||
new_notes
|
||||
)
|
||||
self.refresh_all()
|
||||
|
||||
def on_restart_transcription_action(self):
|
||||
"""Restart transcription for failed or canceled tasks"""
|
||||
selected_rows = self.selectionModel().selectedRows()
|
||||
if not selected_rows:
|
||||
return
|
||||
|
||||
# Get the first selected transcription
|
||||
transcription = self.transcription(selected_rows[0])
|
||||
|
||||
# Check if the task can be restarted
|
||||
if transcription.status not in ["failed", "canceled"]:
|
||||
from PyQt6.QtWidgets import QMessageBox
|
||||
QMessageBox.information(
|
||||
self,
|
||||
_("Cannot Restart"),
|
||||
_("Only failed or canceled transcriptions can be restarted.")
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self.transcription_service.reset_transcription_for_restart(UUID(transcription.id))
|
||||
self._restart_transcription_task(transcription)
|
||||
self.refresh_all()
|
||||
except Exception as e:
|
||||
from PyQt6.QtWidgets import QMessageBox
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
_("Error"),
|
||||
_("Failed to restart transcription: {}").format(str(e))
|
||||
)
|
||||
|
||||
def _restart_transcription_task(self, transcription):
|
||||
"""Create a new FileTranscriptionTask and add it to the queue worker"""
|
||||
from buzz.transcriber.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionOptions,
|
||||
Task
|
||||
)
|
||||
from buzz.model_loader import TranscriptionModel, ModelType
|
||||
from buzz.transcriber.transcriber import OutputFormat
|
||||
|
||||
# Recreate the transcription options from the database record
|
||||
from buzz.model_loader import WhisperModelSize
|
||||
|
||||
# Convert string whisper_model_size to enum if it exists
|
||||
whisper_model_size = None
|
||||
if transcription.whisper_model_size:
|
||||
try:
|
||||
whisper_model_size = WhisperModelSize(transcription.whisper_model_size)
|
||||
except ValueError:
|
||||
# If the stored value is invalid, use a default
|
||||
whisper_model_size = WhisperModelSize.TINY
|
||||
|
||||
transcription_options = TranscriptionOptions(
|
||||
language=transcription.language if transcription.language else None,
|
||||
task=Task(transcription.task) if transcription.task else Task.TRANSCRIBE,
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType(transcription.model_type) if transcription.model_type else ModelType.WHISPER,
|
||||
whisper_model_size=whisper_model_size,
|
||||
hugging_face_model_id=transcription.hugging_face_model_id
|
||||
),
|
||||
word_level_timings=transcription.word_level_timings == "1" if transcription.word_level_timings else False,
|
||||
extract_speech=transcription.extract_speech == "1" if transcription.extract_speech else False,
|
||||
initial_prompt="", # Not stored in database, use default
|
||||
openai_access_token="", # Not stored in database, use default
|
||||
enable_llm_translation=False, # Not stored in database, use default
|
||||
llm_prompt="", # Not stored in database, use default
|
||||
llm_model="" # Not stored in database, use default
|
||||
)
|
||||
|
||||
# Recreate the file transcription options
|
||||
output_formats = set()
|
||||
if transcription.export_formats:
|
||||
for format_str in transcription.export_formats.split(','):
|
||||
try:
|
||||
output_formats.add(OutputFormat(format_str.strip()))
|
||||
except ValueError:
|
||||
pass # Skip invalid formats
|
||||
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
url=transcription.url if transcription.url else None,
|
||||
output_formats=output_formats
|
||||
)
|
||||
|
||||
# Get the model path from the transcription options
|
||||
model_path = transcription_options.model.get_local_model_path()
|
||||
if model_path is None:
|
||||
# If model is not available locally, we need to download it
|
||||
from buzz.model_loader import ModelDownloader
|
||||
ModelDownloader(model=transcription_options.model).run()
|
||||
model_path = transcription_options.model.get_local_model_path()
|
||||
|
||||
if model_path is None:
|
||||
from PyQt6.QtWidgets import QMessageBox
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
_("Error"),
|
||||
_("Could not restart transcription: model not available and could not be downloaded.")
|
||||
)
|
||||
return
|
||||
|
||||
# Create the new task
|
||||
task = FileTranscriptionTask(
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
model_path=model_path,
|
||||
file_path=transcription.file if transcription.file else None,
|
||||
url=transcription.url if transcription.url else None,
|
||||
output_directory=transcription.output_folder if transcription.output_folder else None,
|
||||
source=FileTranscriptionTask.Source(transcription.source) if transcription.source else FileTranscriptionTask.Source.FILE_IMPORT,
|
||||
uid=UUID(transcription.id)
|
||||
)
|
||||
|
||||
# Add the task to the queue worker
|
||||
# We need to access the main window's transcriber worker
|
||||
# This is a bit of a hack, but it's the cleanest way given the current architecture
|
||||
main_window = self.parent()
|
||||
while main_window and not hasattr(main_window, 'transcriber_worker'):
|
||||
main_window = main_window.parent()
|
||||
|
||||
if main_window and hasattr(main_window, 'transcriber_worker'):
|
||||
main_window.transcriber_worker.add_task(task)
|
||||
else:
|
||||
# Fallback: show error if we can't find the transcriber worker
|
||||
from PyQt6.QtWidgets import QMessageBox
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
_("Error"),
|
||||
_("Could not restart transcription: transcriber worker not found.")
|
||||
)
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
from PyQt6.QtGui import QAction
|
||||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QWidget, QMenu, QFileDialog
|
||||
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
|
|
@ -18,48 +17,36 @@ class ExportTranscriptionMenu(QMenu):
|
|||
self,
|
||||
transcription: Transcription,
|
||||
transcription_service: TranscriptionService,
|
||||
has_translation: bool,
|
||||
translation: pyqtSignal,
|
||||
parent: QWidget | None = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.transcription = transcription
|
||||
self.transcription_service = transcription_service
|
||||
self.segments = []
|
||||
self.load_segments()
|
||||
|
||||
translation.connect(self.on_translation_available)
|
||||
|
||||
text_label = _("Text")
|
||||
translation_label = _("Translation")
|
||||
self.text_actions = [
|
||||
QAction(text=f"{output_format.value.upper()} - {text_label}", parent=self)
|
||||
for output_format in OutputFormat
|
||||
]
|
||||
self.translation_actions = [
|
||||
QAction(text=f"{output_format.value.upper()} - {translation_label}", parent=self)
|
||||
for output_format in OutputFormat
|
||||
]
|
||||
for action in self.translation_actions:
|
||||
action.setVisible(has_translation)
|
||||
actions = self.text_actions + self.translation_actions
|
||||
if self.segments and len(self.segments[0].translation) > 0:
|
||||
text_label = _("Text")
|
||||
translation_label = _("Translation")
|
||||
actions = [
|
||||
action
|
||||
for output_format in OutputFormat
|
||||
for action in [
|
||||
QAction(text=f"{output_format.value.upper()} - {text_label}", parent=self),
|
||||
QAction(text=f"{output_format.value.upper()} - {translation_label}", parent=self)
|
||||
]
|
||||
]
|
||||
else:
|
||||
actions = [
|
||||
QAction(text=output_format.value.upper(), parent=self)
|
||||
for output_format in OutputFormat
|
||||
]
|
||||
self.addActions(actions)
|
||||
self.triggered.connect(self.on_menu_triggered)
|
||||
|
||||
@staticmethod
|
||||
def extract_format_and_segment_key(action_text: str):
|
||||
parts = action_text.split('-')
|
||||
output_format = parts[0].strip()
|
||||
label = parts[1].strip() if len(parts) > 1 else None
|
||||
segment_key = 'translation' if label == _('Translation') else 'text'
|
||||
|
||||
return output_format, segment_key
|
||||
|
||||
def on_translation_available(self):
|
||||
for action in self.translation_actions:
|
||||
action.setVisible(True)
|
||||
|
||||
def on_menu_triggered(self, action: QAction):
|
||||
segments = [
|
||||
def load_segments(self):
|
||||
self.segments = [
|
||||
Segment(
|
||||
start=segment.start_time,
|
||||
end=segment.end_time,
|
||||
|
|
@ -69,9 +56,18 @@ class ExportTranscriptionMenu(QMenu):
|
|||
transcription_id=self.transcription.id_as_uuid
|
||||
)
|
||||
]
|
||||
@staticmethod
|
||||
def extract_format_and_segment_key(action_text: str):
|
||||
parts = action_text.split('-')
|
||||
output_format = parts[0].strip()
|
||||
label = parts[1].strip() if len(parts) > 1 else None
|
||||
segment_key = 'translation' if label == _('Translation') else 'text'
|
||||
|
||||
return output_format, segment_key
|
||||
|
||||
def on_menu_triggered(self, action: QAction):
|
||||
output_format_value, segment_key = self.extract_format_and_segment_key(action.text())
|
||||
output_format = OutputFormat(output_format_value.lower())
|
||||
output_format = OutputFormat[output_format_value]
|
||||
|
||||
default_path = self.transcription.get_output_file_path(
|
||||
output_format=output_format
|
||||
|
|
@ -87,9 +83,12 @@ class ExportTranscriptionMenu(QMenu):
|
|||
if output_file_path == "":
|
||||
return
|
||||
|
||||
# Reload segments in case they were resized
|
||||
self.load_segments()
|
||||
|
||||
write_output(
|
||||
path=output_file_path,
|
||||
segments=segments,
|
||||
segments=self.segments,
|
||||
output_format=output_format,
|
||||
segment_key=segment_key
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,800 +0,0 @@
|
|||
import re
|
||||
import os
|
||||
import logging
|
||||
import ssl
|
||||
import time
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
# Fix SSL certificate verification for bundled applications (macOS, Windows)
|
||||
# This must be done before importing libraries that download from Hugging Face
|
||||
try:
|
||||
import certifi
|
||||
os.environ.setdefault('REQUESTS_CA_BUNDLE', certifi.where())
|
||||
os.environ.setdefault('SSL_CERT_FILE', certifi.where())
|
||||
os.environ.setdefault('SSL_CERT_DIR', os.path.dirname(certifi.where()))
|
||||
# Also update the default SSL context for urllib
|
||||
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import faster_whisper
|
||||
import torch
|
||||
from PyQt6.QtMultimedia import QMediaPlayer, QAudioOutput
|
||||
from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QTimer
|
||||
from PyQt6.QtGui import QFont
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QFormLayout,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QProgressBar,
|
||||
QPushButton,
|
||||
QCheckBox,
|
||||
QGroupBox,
|
||||
QSpacerItem,
|
||||
QSizePolicy,
|
||||
QLayout,
|
||||
)
|
||||
from buzz.locale import _
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
from buzz.paths import file_path_as_title
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.transcriber.transcriber import Segment
|
||||
|
||||
|
||||
|
||||
def process_in_batches(
|
||||
items,
|
||||
process_func,
|
||||
batch_size=200,
|
||||
chunk_size=230,
|
||||
smaller_batch_size=100,
|
||||
exception_types=(AssertionError,),
|
||||
**process_func_kwargs
|
||||
):
|
||||
"""
|
||||
Process items in batches with automatic fallback to smaller batches on errors.
|
||||
|
||||
This is a generic batch processing function that can be used with any processing
|
||||
function that has chunk size limitations. It automatically retries with smaller
|
||||
batches when specified exceptions occur.
|
||||
|
||||
Args:
|
||||
items: List of items to process
|
||||
process_func: Callable that processes a batch. Should accept (batch, chunk_size, **kwargs)
|
||||
and return a list of results
|
||||
batch_size: Initial batch size (default: 200)
|
||||
chunk_size: Maximum chunk size for the processing function (default: 230)
|
||||
smaller_batch_size: Fallback batch size when errors occur (default: 100)
|
||||
exception_types: Tuple of exception types to catch and retry with smaller batches
|
||||
(default: (AssertionError,))
|
||||
**process_func_kwargs: Additional keyword arguments to pass to process_func
|
||||
|
||||
Returns:
|
||||
List of processed results (concatenated from all batches)
|
||||
|
||||
Example:
|
||||
>>> def my_predict(batch, chunk_size):
|
||||
... return [f"processed_{item}" for item in batch]
|
||||
>>> results = process_in_batches(
|
||||
... items=["a", "b", "c"],
|
||||
... process_func=my_predict,
|
||||
... batch_size=2
|
||||
... )
|
||||
"""
|
||||
all_results = []
|
||||
|
||||
for i in range(0, len(items), batch_size):
|
||||
batch = items[i:i + batch_size]
|
||||
try:
|
||||
batch_results = process_func(batch, chunk_size=min(chunk_size, len(batch)), **process_func_kwargs)
|
||||
all_results.extend(batch_results)
|
||||
except exception_types as e:
|
||||
# If batch still fails, try with even smaller chunks
|
||||
logging.warning(f"Batch processing failed, trying smaller chunks: {e}")
|
||||
for j in range(0, len(batch), smaller_batch_size):
|
||||
smaller_batch = batch[j:j + smaller_batch_size]
|
||||
smaller_results = process_func(smaller_batch, chunk_size=min(chunk_size, len(smaller_batch)), **process_func_kwargs)
|
||||
all_results.extend(smaller_results)
|
||||
|
||||
return all_results
|
||||
|
||||
SENTENCE_END = re.compile(r'.*[.!?。!?]')
|
||||
|
||||
class IdentificationWorker(QObject):
|
||||
finished = pyqtSignal(list)
|
||||
progress_update = pyqtSignal(str)
|
||||
error = pyqtSignal(str)
|
||||
|
||||
def __init__(self, transcription, transcription_service):
|
||||
super().__init__()
|
||||
self.transcription = transcription
|
||||
self.transcription_service = transcription_service
|
||||
self._is_cancelled = False
|
||||
|
||||
def cancel(self):
|
||||
"""Request cancellation of the worker."""
|
||||
self._is_cancelled = True
|
||||
|
||||
def get_transcript(self, audio, **kwargs) -> dict:
|
||||
buzz_segments = self.transcription_service.get_transcription_segments(
|
||||
transcription_id=self.transcription.id_as_uuid
|
||||
)
|
||||
|
||||
segments = []
|
||||
words = []
|
||||
text = ""
|
||||
for buzz_segment in buzz_segments:
|
||||
words.append({
|
||||
'word': buzz_segment.text + " ",
|
||||
'start': buzz_segment.start_time / 100,
|
||||
'end': buzz_segment.end_time / 100,
|
||||
})
|
||||
text += buzz_segment.text + " "
|
||||
|
||||
if SENTENCE_END.match(buzz_segment.text):
|
||||
segments.append({
|
||||
'text': text,
|
||||
'words': words
|
||||
})
|
||||
words = []
|
||||
text = ""
|
||||
|
||||
return {
|
||||
'language': self.transcription.language,
|
||||
'segments': segments
|
||||
}
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
from ctc_forced_aligner.ctc_forced_aligner import (
|
||||
generate_emissions,
|
||||
get_alignments,
|
||||
get_spans,
|
||||
load_alignment_model,
|
||||
postprocess_results,
|
||||
preprocess_text,
|
||||
)
|
||||
from whisper_diarization.helpers import (
|
||||
get_realigned_ws_mapping_with_punctuation,
|
||||
get_sentences_speaker_mapping,
|
||||
get_words_speaker_mapping,
|
||||
langs_to_iso,
|
||||
punct_model_langs,
|
||||
)
|
||||
from deepmultilingualpunctuation.deepmultilingualpunctuation import PunctuationModel
|
||||
from whisper_diarization.diarization import MSDDDiarizer
|
||||
except ImportError as e:
|
||||
logging.exception("Failed to import speaker identification libraries: %s", e)
|
||||
self.error.emit(
|
||||
_("Speaker identification is not available: failed to load required libraries.")
|
||||
+ f"\n\n{e}"
|
||||
)
|
||||
return
|
||||
|
||||
diarizer_model = None
|
||||
alignment_model = None
|
||||
|
||||
try:
|
||||
logging.debug("Speaker identification worker: Starting")
|
||||
self.progress_update.emit(_("1/8 Collecting transcripts"))
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 1")
|
||||
return
|
||||
|
||||
# Step 1 - Get transcript
|
||||
# TODO - Add detected language to the transcript, detect and store separately in metadata
|
||||
# Will also be relevant for template parsing of transcript file names
|
||||
# - See diarize.py for example on how to get this info from whisper transcript, maybe other whisper models also have it
|
||||
language = self.transcription.language if self.transcription.language else "en"
|
||||
|
||||
segments = self.transcription_service.get_transcription_segments(
|
||||
transcription_id=self.transcription.id_as_uuid
|
||||
)
|
||||
|
||||
full_transcript = " ".join(segment.text for segment in segments)
|
||||
full_transcript = re.sub(r' {2,}', ' ', full_transcript)
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 2")
|
||||
return
|
||||
|
||||
self.progress_update.emit(_("2/8 Loading audio"))
|
||||
audio_waveform = faster_whisper.decode_audio(self.transcription.file)
|
||||
|
||||
# Step 2 - Forced alignment
|
||||
force_cpu = os.getenv("BUZZ_FORCE_CPU", "false")
|
||||
use_cuda = torch.cuda.is_available() and force_cpu == "false"
|
||||
device = "cuda" if use_cuda else "cpu"
|
||||
torch_dtype = torch.float16 if use_cuda else torch.float32
|
||||
|
||||
logging.debug(f"Speaker identification worker: Using device={device}")
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 3")
|
||||
return
|
||||
|
||||
self.progress_update.emit(_("3/8 Loading alignment model"))
|
||||
alignment_model = None
|
||||
alignment_tokenizer = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
alignment_model, alignment_tokenizer = load_alignment_model(
|
||||
device,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
if attempt < 2:
|
||||
logging.warning(
|
||||
f"Speaker identification: Failed to load alignment model "
|
||||
f"(attempt {attempt + 1}/3), retrying: {e}"
|
||||
)
|
||||
# On retry, try using cached models only (offline mode)
|
||||
# Set at runtime by modifying the library constants directly
|
||||
# (env vars are only read at import time)
|
||||
try:
|
||||
import huggingface_hub.constants
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = True
|
||||
logging.debug("Speaker identification: Enabled HF offline mode")
|
||||
except Exception as offline_err:
|
||||
logging.warning(f"Failed to set offline mode: {offline_err}")
|
||||
self.progress_update.emit(
|
||||
_("3/8 Loading alignment model (retrying with cache...)")
|
||||
)
|
||||
time.sleep(2 ** attempt) # 1s, 2s backoff
|
||||
else:
|
||||
raise RuntimeError(
|
||||
_("Failed to load alignment model. "
|
||||
"Please check your internet connection and try again.")
|
||||
) from e
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 4")
|
||||
return
|
||||
|
||||
self.progress_update.emit(_("4/8 Processing audio"))
|
||||
logging.debug("Speaker identification worker: Generating emissions")
|
||||
emissions, stride = generate_emissions(
|
||||
alignment_model,
|
||||
torch.from_numpy(audio_waveform)
|
||||
.to(alignment_model.dtype)
|
||||
.to(alignment_model.device),
|
||||
batch_size=1 if device == "cpu" else 8,
|
||||
)
|
||||
logging.debug("Speaker identification worker: Emissions generated")
|
||||
|
||||
# Clean up alignment model
|
||||
del alignment_model
|
||||
alignment_model = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 5")
|
||||
return
|
||||
|
||||
self.progress_update.emit(_("5/8 Preparing transcripts"))
|
||||
tokens_starred, text_starred = preprocess_text(
|
||||
full_transcript,
|
||||
romanize=True,
|
||||
language=langs_to_iso[language],
|
||||
)
|
||||
|
||||
segments, scores, blank_token = get_alignments(
|
||||
emissions,
|
||||
tokens_starred,
|
||||
alignment_tokenizer,
|
||||
)
|
||||
|
||||
spans = get_spans(tokens_starred, segments, blank_token)
|
||||
|
||||
word_timestamps = postprocess_results(text_starred, spans, stride, scores)
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 6")
|
||||
return
|
||||
|
||||
# Step 3 - Diarization
|
||||
self.progress_update.emit(_("6/8 Identifying speakers"))
|
||||
|
||||
# Silence NeMo's verbose logging
|
||||
logging.getLogger("nemo_logging").setLevel(logging.ERROR)
|
||||
try:
|
||||
# Also try to silence NeMo's internal logging system
|
||||
from nemo.utils import logging as nemo_logging
|
||||
nemo_logging.setLevel(logging.ERROR)
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
logging.debug("Speaker identification worker: Creating diarizer model")
|
||||
diarizer_model = MSDDDiarizer(device)
|
||||
logging.debug("Speaker identification worker: Running diarization (this may take a while on CPU)")
|
||||
speaker_ts = diarizer_model.diarize(torch.from_numpy(audio_waveform).unsqueeze(0))
|
||||
logging.debug("Speaker identification worker: Diarization complete")
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled after diarization")
|
||||
return
|
||||
|
||||
# Clean up diarizer model immediately after use
|
||||
del diarizer_model
|
||||
diarizer_model = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self._is_cancelled:
|
||||
logging.debug("Speaker identification worker: Cancelled at step 7")
|
||||
return
|
||||
|
||||
# Step 4 - Reading timestamps <> Speaker Labels mapping
|
||||
self.progress_update.emit(_("7/8 Mapping speakers to transcripts"))
|
||||
|
||||
wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")
|
||||
|
||||
if language in punct_model_langs:
|
||||
# restoring punctuation in the transcript to help realign the sentences
|
||||
punct_model = PunctuationModel(model="kredor/punctuate-all")
|
||||
|
||||
words_list = list(map(lambda x: x["word"], wsm))
|
||||
|
||||
# Process in batches to avoid chunk size errors
|
||||
def predict_wrapper(batch, chunk_size, **kwargs):
|
||||
return punct_model.predict(batch, chunk_size=chunk_size)
|
||||
|
||||
labled_words = process_in_batches(
|
||||
items=words_list,
|
||||
process_func=predict_wrapper
|
||||
)
|
||||
|
||||
ending_puncts = ".?!。!?"
|
||||
model_puncts = ".,;:!?。!?"
|
||||
|
||||
# We don't want to punctuate U.S.A. with a period. Right?
|
||||
is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
|
||||
|
||||
for word_dict, labeled_tuple in zip(wsm, labled_words):
|
||||
word = word_dict["word"]
|
||||
if (
|
||||
word
|
||||
and labeled_tuple[1] in ending_puncts
|
||||
and (word[-1] not in model_puncts or is_acronym(word))
|
||||
):
|
||||
word += labeled_tuple[1]
|
||||
if word.endswith(".."):
|
||||
word = word.rstrip(".")
|
||||
word_dict["word"] = word
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
f"Punctuation restoration is not available for {language} language."
|
||||
" Using the original punctuation."
|
||||
)
|
||||
|
||||
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
|
||||
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
|
||||
|
||||
logging.debug("Speaker identification worker: Finished successfully")
|
||||
self.progress_update.emit(_("8/8 Identification done"))
|
||||
self.finished.emit(ssm)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Speaker identification worker: Error - {e}", exc_info=True)
|
||||
self.progress_update.emit(_("0/0 Error identifying speakers"))
|
||||
self.error.emit(str(e))
|
||||
# Emit empty list so the UI can reset properly
|
||||
self.finished.emit([])
|
||||
|
||||
finally:
|
||||
# Ensure cleanup happens regardless of how we exit
|
||||
logging.debug("Speaker identification worker: Cleaning up resources")
|
||||
if diarizer_model is not None:
|
||||
try:
|
||||
del diarizer_model
|
||||
except Exception:
|
||||
pass
|
||||
if alignment_model is not None:
|
||||
try:
|
||||
del alignment_model
|
||||
except Exception:
|
||||
pass
|
||||
torch.cuda.empty_cache()
|
||||
# Reset offline mode so it doesn't affect other operations
|
||||
try:
|
||||
import huggingface_hub.constants
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class SpeakerIdentificationWidget(QWidget):
|
||||
resize_button_clicked = pyqtSignal()
|
||||
transcription: Transcription
|
||||
settings = Settings()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transcription: Transcription,
|
||||
transcription_service: TranscriptionService,
|
||||
parent: Optional["QWidget"] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||
transcriptions_updated_signal: Optional[pyqtSignal] = None,
|
||||
) -> None:
|
||||
super().__init__(parent, flags)
|
||||
self.transcription = transcription
|
||||
self.transcription_service = transcription_service
|
||||
self.transcriptions_updated_signal = transcriptions_updated_signal
|
||||
|
||||
self.identification_result = None
|
||||
|
||||
self.thread = None
|
||||
self.worker = None
|
||||
self.needs_layout_update = False
|
||||
|
||||
self.setMinimumWidth(650)
|
||||
self.setMinimumHeight(400)
|
||||
|
||||
self.setWindowTitle(file_path_as_title(transcription.file))
|
||||
|
||||
layout = QFormLayout(self)
|
||||
layout.setSizeConstraint(QLayout.SizeConstraint.SetMinAndMaxSize)
|
||||
|
||||
# Step 1: Identify speakers
|
||||
step_1_label = QLabel(_("Step 1: Identify speakers"), self)
|
||||
font = step_1_label.font()
|
||||
font.setWeight(QFont.Weight.Bold)
|
||||
step_1_label.setFont(font)
|
||||
layout.addRow(step_1_label)
|
||||
|
||||
step_1_group_box = QGroupBox(self)
|
||||
step_1_group_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||
step_1_layout = QVBoxLayout(step_1_group_box)
|
||||
|
||||
self.step_1_row = QHBoxLayout()
|
||||
|
||||
self.step_1_button = QPushButton(_("Identify"))
|
||||
self.step_1_button.setMinimumWidth(200)
|
||||
self.step_1_button.clicked.connect(self.on_identify_button_clicked)
|
||||
|
||||
self.cancel_button = QPushButton(_("Cancel"))
|
||||
self.cancel_button.setMinimumWidth(200)
|
||||
self.cancel_button.setVisible(False)
|
||||
self.cancel_button.clicked.connect(self.on_cancel_button_clicked)
|
||||
|
||||
# Progress container with label and bar
|
||||
progress_container = QVBoxLayout()
|
||||
|
||||
self.progress_label = QLabel(self)
|
||||
if os.path.isfile(self.transcription.file):
|
||||
self.progress_label.setText(_("Ready to identify speakers"))
|
||||
else:
|
||||
self.progress_label.setText(_("Audio file not found"))
|
||||
self.step_1_button.setEnabled(False)
|
||||
|
||||
self.progress_bar = QProgressBar(self)
|
||||
self.progress_bar.setMinimumWidth(400)
|
||||
self.progress_bar.setRange(0, 8)
|
||||
self.progress_bar.setValue(0)
|
||||
|
||||
progress_container.addWidget(self.progress_label)
|
||||
progress_container.addWidget(self.progress_bar)
|
||||
|
||||
self.step_1_row.addLayout(progress_container)
|
||||
|
||||
button_container = QVBoxLayout()
|
||||
button_container.addWidget(self.step_1_button)
|
||||
button_container.addWidget(self.cancel_button)
|
||||
self.step_1_row.addLayout(button_container)
|
||||
|
||||
step_1_layout.addLayout(self.step_1_row)
|
||||
|
||||
layout.addRow(step_1_group_box)
|
||||
|
||||
# Spacer
|
||||
spacer = QSpacerItem(0, 10, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed)
|
||||
layout.addItem(spacer)
|
||||
|
||||
# Step 2: Name speakers
|
||||
step_2_label = QLabel(_("Step 2: Name speakers"), self)
|
||||
font = step_2_label.font()
|
||||
font.setWeight(QFont.Weight.Bold)
|
||||
step_2_label.setFont(font)
|
||||
layout.addRow(step_2_label)
|
||||
|
||||
self.step_2_group_box = QGroupBox(self)
|
||||
self.step_2_group_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||
self.step_2_group_box.setEnabled(False)
|
||||
step_2_layout = QVBoxLayout(self.step_2_group_box)
|
||||
|
||||
self.speaker_preview_row = QVBoxLayout()
|
||||
|
||||
self.speaker_0_input = LineEdit("Speaker 0", self)
|
||||
|
||||
self.speaker_0_preview_button = QPushButton(_("Play sample"))
|
||||
self.speaker_0_preview_button.setMinimumWidth(200)
|
||||
self.speaker_0_preview_button.clicked.connect(lambda: self.on_speaker_preview("Speaker 0"))
|
||||
|
||||
speaker_0_layout = QHBoxLayout()
|
||||
speaker_0_layout.addWidget(self.speaker_0_input)
|
||||
speaker_0_layout.addWidget(self.speaker_0_preview_button)
|
||||
|
||||
self.speaker_preview_row.addLayout(speaker_0_layout)
|
||||
|
||||
step_2_layout.addLayout(self.speaker_preview_row)
|
||||
|
||||
layout.addRow(self.step_2_group_box)
|
||||
|
||||
# Save button
|
||||
self.merge_speaker_sentences = QCheckBox(_("Merge speaker sentences"))
|
||||
self.merge_speaker_sentences.setChecked(True)
|
||||
self.merge_speaker_sentences.setEnabled(False)
|
||||
self.merge_speaker_sentences.setMinimumWidth(250)
|
||||
|
||||
self.save_button = QPushButton(_("Save"))
|
||||
self.save_button.setEnabled(False)
|
||||
self.save_button.clicked.connect(self.on_save_button_clicked)
|
||||
|
||||
layout.addRow(self.merge_speaker_sentences)
|
||||
layout.addRow(self.save_button)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
# Invisible preview player
|
||||
url = QUrl.fromLocalFile(self.transcription.file)
|
||||
self.player = QMediaPlayer()
|
||||
self.audio_output = QAudioOutput()
|
||||
self.player.setAudioOutput(self.audio_output)
|
||||
self.player.setSource(url)
|
||||
self.player_timer = None
|
||||
|
||||
def on_identify_button_clicked(self):
|
||||
self.step_1_button.setEnabled(False)
|
||||
self.step_1_button.setVisible(False)
|
||||
self.cancel_button.setVisible(True)
|
||||
|
||||
# Clean up any existing thread before starting a new one
|
||||
self._cleanup_thread()
|
||||
|
||||
logging.debug("Speaker identification: Starting identification thread")
|
||||
|
||||
self.thread = QThread()
|
||||
self.worker = IdentificationWorker(
|
||||
self.transcription,
|
||||
self.transcription_service
|
||||
)
|
||||
self.worker.moveToThread(self.thread)
|
||||
self.thread.started.connect(self.worker.run)
|
||||
self.worker.finished.connect(self._on_thread_finished)
|
||||
self.worker.progress_update.connect(self.on_progress_update)
|
||||
self.worker.error.connect(self.on_identification_error)
|
||||
|
||||
self.thread.start()
|
||||
|
||||
def on_cancel_button_clicked(self):
|
||||
"""Handle cancel button click."""
|
||||
logging.debug("Speaker identification: Cancel requested by user")
|
||||
self.cancel_button.setEnabled(False)
|
||||
self.progress_label.setText(_("Cancelling..."))
|
||||
self._cleanup_thread()
|
||||
self._reset_buttons()
|
||||
self.progress_label.setText(_("Cancelled"))
|
||||
self.progress_bar.setValue(0)
|
||||
|
||||
def _reset_buttons(self):
|
||||
"""Reset identify/cancel buttons to initial state."""
|
||||
self.step_1_button.setVisible(True)
|
||||
self.step_1_button.setEnabled(True)
|
||||
self.cancel_button.setVisible(False)
|
||||
self.cancel_button.setEnabled(True)
|
||||
|
||||
def _on_thread_finished(self, result):
|
||||
"""Handle thread completion and cleanup."""
|
||||
logging.debug("Speaker identification: Thread finished")
|
||||
if self.thread is not None:
|
||||
self.thread.quit()
|
||||
self.thread.wait(5000)
|
||||
self._reset_buttons()
|
||||
self.on_identification_finished(result)
|
||||
|
||||
def on_identification_error(self, error_message):
|
||||
"""Handle identification error."""
|
||||
logging.error(f"Speaker identification error: {error_message}")
|
||||
self._reset_buttons()
|
||||
self.progress_bar.setValue(0)
|
||||
|
||||
def on_progress_update(self, progress):
|
||||
self.progress_label.setText(progress)
|
||||
|
||||
progress_value = 0
|
||||
if progress and progress[0].isdigit():
|
||||
progress_value = int(progress[0])
|
||||
self.progress_bar.setValue(progress_value)
|
||||
else:
|
||||
logging.error(f"Invalid progress format: {progress}")
|
||||
|
||||
if progress_value == 8:
|
||||
self.step_2_group_box.setEnabled(True)
|
||||
self.merge_speaker_sentences.setEnabled(True)
|
||||
self.save_button.setEnabled(True)
|
||||
|
||||
def on_identification_finished(self, result):
|
||||
self.identification_result = result
|
||||
|
||||
# Handle empty results (error case)
|
||||
if not result:
|
||||
logging.debug("Speaker identification: Empty result received")
|
||||
return
|
||||
|
||||
unique_speakers = {entry['speaker'] for entry in result}
|
||||
|
||||
while self.speaker_preview_row.count():
|
||||
item = self.speaker_preview_row.takeAt(0)
|
||||
widget = item.widget()
|
||||
if widget:
|
||||
widget.deleteLater()
|
||||
else:
|
||||
layout = item.layout()
|
||||
if layout:
|
||||
while layout.count():
|
||||
sub_item = layout.takeAt(0)
|
||||
sub_widget = sub_item.widget()
|
||||
if sub_widget:
|
||||
sub_widget.deleteLater()
|
||||
|
||||
for speaker in sorted(unique_speakers):
|
||||
speaker_input = LineEdit(speaker, self)
|
||||
speaker_input.setMinimumWidth(200)
|
||||
|
||||
speaker_preview_button = QPushButton(_("Play sample"))
|
||||
speaker_preview_button.setMinimumWidth(200)
|
||||
speaker_preview_button.clicked.connect(lambda checked, s=speaker: self.on_speaker_preview(s))
|
||||
|
||||
speaker_layout = QHBoxLayout()
|
||||
speaker_layout.addWidget(speaker_input)
|
||||
speaker_layout.addWidget(speaker_preview_button)
|
||||
|
||||
self.speaker_preview_row.addLayout(speaker_layout)
|
||||
|
||||
# Trigger layout update to properly size the new widgets
|
||||
self.layout().activate()
|
||||
self.adjustSize()
|
||||
# Schedule update if window is minimized
|
||||
self.needs_layout_update = True
|
||||
|
||||
def on_speaker_preview(self, speaker_id):
|
||||
if self.player_timer:
|
||||
self.player_timer.stop()
|
||||
|
||||
speaker_records = [record for record in self.identification_result if record['speaker'] == speaker_id]
|
||||
|
||||
if speaker_records:
|
||||
random_record = random.choice(speaker_records)
|
||||
|
||||
start_time = random_record['start_time']
|
||||
end_time = random_record['end_time']
|
||||
|
||||
self.player.setPosition(int(start_time))
|
||||
self.player.play()
|
||||
|
||||
self.player_timer = QTimer(self)
|
||||
self.player_timer.setSingleShot(True)
|
||||
self.player_timer.timeout.connect(self.player.stop)
|
||||
self.player_timer.start(min(end_time, 10 * 1000)) # 10 seconds
|
||||
|
||||
def on_save_button_clicked(self):
    """Apply the user-entered speaker names and save a new transcription.

    Collects the edited names from the preview rows, maps them onto the
    original diarization labels, builds speaker-prefixed segments
    (optionally merging consecutive sentences of the same speaker), writes
    them into a copy of the transcription, and closes the dialog.
    """
    # Collect the edited names from the LineEdits inside each preview row.
    # Row order matches the sorted() order used when the rows were built,
    # so names can be zipped against sorted original speaker labels below.
    speaker_names = []
    for i in range(self.speaker_preview_row.count()):
        item = self.speaker_preview_row.itemAt(i)
        if item.layout():
            for j in range(item.layout().count()):
                sub_item = item.layout().itemAt(j)
                widget = sub_item.widget()
                if isinstance(widget, LineEdit):
                    speaker_names.append(widget.text())

    # Map original diarization labels -> user-provided display names.
    unique_speakers = {entry['speaker'] for entry in self.identification_result}
    original_speakers = sorted(unique_speakers)
    speaker_mapping = dict(zip(original_speakers, speaker_names))

    segments = []
    if self.merge_speaker_sentences.isChecked():
        # Merge consecutive entries of the same speaker into one segment.
        previous_segment = None

        for entry in self.identification_result:
            speaker_name = speaker_mapping.get(entry['speaker'], entry['speaker'])

            if previous_segment and previous_segment['speaker'] == speaker_name:
                # Same speaker continues: extend the pending segment.
                previous_segment['end_time'] = entry['end_time']
                previous_segment['text'] += " " + entry['text']
            else:
                # Speaker changed: flush the pending segment, start a new one.
                if previous_segment:
                    segment = Segment(
                        start=previous_segment['start_time'],
                        end=previous_segment['end_time'],
                        text=f"{previous_segment['speaker']}: {previous_segment['text']}"
                    )
                    segments.append(segment)
                previous_segment = {
                    'start_time': entry['start_time'],
                    'end_time': entry['end_time'],
                    'speaker': speaker_name,
                    'text': entry['text']
                }

        # Flush the final pending segment after the loop.
        if previous_segment:
            segment = Segment(
                start=previous_segment['start_time'],
                end=previous_segment['end_time'],
                text=f"{previous_segment['speaker']}: {previous_segment['text']}"
            )
            segments.append(segment)
    else:
        # One output segment per diarization entry, prefixed with its speaker.
        for entry in self.identification_result:
            speaker_name = speaker_mapping.get(entry['speaker'], entry['speaker'])
            segment = Segment(
                start=entry['start_time'],
                end=entry['end_time'],
                text=f"{speaker_name}: {entry['text']}"
            )
            segments.append(segment)

    # Save into a fresh copy so the original transcription stays untouched.
    new_transcript_id = self.transcription_service.copy_transcription(
        self.transcription.id_as_uuid
    )

    self.transcription_service.update_transcription_as_completed(new_transcript_id, segments)

    # TODO - See if we can get rows in the transcription viewer to be of variable height
    # If text is longer they should expand
    if self.transcriptions_updated_signal:
        self.transcriptions_updated_signal.emit(new_transcript_id)

    # Stop any running preview playback before closing.
    self.player.stop()

    if self.player_timer:
        self.player_timer.stop()

    self.close()
|
||||
|
||||
def changeEvent(self, event):
    """Run any deferred layout refresh when the window state changes."""
    super().changeEvent(event)

    if not self.needs_layout_update:
        return

    # A layout refresh was requested while the window was minimized or
    # hidden; perform it now that the widget received a state change.
    self.layout().activate()
    self.adjustSize()
    self.needs_layout_update = False
|
||||
|
||||
def closeEvent(self, event):
    """Hide the dialog, stop playback, tear down the worker thread."""
    self.hide()

    # Halt any in-progress sample preview.
    self.player.stop()
    timer = self.player_timer
    if timer:
        timer.stop()

    # Make sure the background worker is not left running.
    self._cleanup_thread()

    super().closeEvent(event)
|
||||
|
||||
def _cleanup_thread(self):
    """Properly clean up the worker thread.

    Asks the worker to cancel, then quits the QThread, escalating to
    terminate() only if it fails to stop within 10 seconds. Finally drops
    the references so a new run can be started.
    """
    if self.worker is not None:
        # Request cooperative cancellation before touching the thread.
        self.worker.cancel()

    if self.thread is not None and self.thread.isRunning():
        logging.debug("Speaker identification: Stopping running thread")
        self.thread.quit()
        if not self.thread.wait(10000):  # Wait up to 10 seconds for a clean exit
            # Last resort: terminate() may leave resources in an
            # undefined state, so it is only used after quit() times out.
            logging.warning("Speaker identification: Thread did not quit, terminating")
            self.thread.terminate()
            if not self.thread.wait(2000):
                logging.error("Speaker identification: Thread failed to terminate")

    # Drop references so repeated cleanup calls are harmless.
    self.thread = None
    self.worker = None
|
||||
|
|
@ -1,444 +0,0 @@
|
|||
import re
|
||||
import os
|
||||
import logging
|
||||
import stable_whisper
|
||||
import srt
|
||||
from pathlib import Path
|
||||
from srt_equalizer import srt_equalizer
|
||||
from typing import Optional
|
||||
from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal
|
||||
from PyQt6.QtGui import QFont
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QFormLayout,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QSpinBox,
|
||||
QPushButton,
|
||||
QCheckBox,
|
||||
QGroupBox,
|
||||
QSpacerItem,
|
||||
QSizePolicy,
|
||||
)
|
||||
from buzz.locale import _, languages
|
||||
from buzz import whisper_audio
|
||||
from buzz.db.entity.transcription import Transcription
|
||||
from buzz.db.service.transcription_service import TranscriptionService
|
||||
from buzz.paths import file_path_as_title
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.widgets.line_edit import LineEdit
|
||||
from buzz.transcriber.transcriber import Segment
|
||||
from buzz.widgets.preferences_dialog.models.file_transcription_preferences import (
|
||||
FileTranscriptionPreferences,
|
||||
)
|
||||
|
||||
|
||||
# Matches text whose final character is sentence-ending punctuation:
# Latin ".", "!", "?" plus the full-width CJK forms "。", "!", "?".
# Used to decide where to flush a sentence while rebuilding word groups.
SENTENCE_END = re.compile(r'.*[.!?。!?]')

# Languages that don't use spaces between words
# (ISO 639-1 codes; for these, words are concatenated without a separator).
NON_SPACE_LANGUAGES = {"zh", "ja", "th", "lo", "km", "my"}
|
||||
|
||||
class TranscriptionWorker(QObject):
    """Background worker that re-groups an existing transcription.

    Feeds the stored word-level segments to ``stable_whisper.transcribe_any``
    with a regroup rule string and emits the resulting merged segments via
    the ``finished`` signal. Designed to be moved to a QThread.
    """

    # Emits the list of regrouped Segment objects when the run completes.
    finished = pyqtSignal(list)

    def __init__(self, transcription, transcription_options, transcription_service, regroup_string: str):
        """Store the transcription, its options/service, and the stable-whisper regroup rule."""
        super().__init__()
        self.transcription = transcription
        self.transcription_options = transcription_options
        self.transcription_service = transcription_service
        self.regroup_string = regroup_string

    def get_transcript(self, audio, **kwargs) -> dict:
        """Build a whisper-style result dict from the stored segments.

        Used as the "model" callable for ``stable_whisper.transcribe_any``:
        returns ``{'language': ..., 'segments': [{'text', 'words'}, ...]}``
        where each word carries start/end times in seconds.
        """
        buzz_segments = self.transcription_service.get_transcription_segments(
            transcription_id=self.transcription.id_as_uuid
        )

        # Check if the language uses spaces between words
        language = self.transcription.language or ""
        is_non_space_language = language in NON_SPACE_LANGUAGES

        # For non-space languages, don't add spaces between words
        separator = "" if is_non_space_language else " "

        segments = []
        words = []
        text = ""
        for buzz_segment in buzz_segments:
            # NOTE(review): stored times are divided by 100 to get seconds,
            # which implies a 10 ms storage unit here — confirm against the
            # segment schema (other code paths treat times as milliseconds).
            words.append({
                'word': buzz_segment.text + separator,
                'start': buzz_segment.start_time / 100,
                'end': buzz_segment.end_time / 100,
            })
            text += buzz_segment.text + separator

            # Flush the accumulated words as one segment at sentence ends.
            if SENTENCE_END.match(buzz_segment.text):
                segments.append({
                    'text': text,
                    'words': words
                })
                words = []
                text = ""

        # Add any remaining words that weren't terminated by sentence-ending punctuation
        if words:
            segments.append({
                'text': text,
                'words': words
            })

        return {
            'language': self.transcription.language,
            'segments': segments
        }

    def run(self):
        """Run the regrouping pass and emit ``finished`` with new segments.

        On any error the exception is logged and nothing is emitted.
        """
        transcription_file = self.transcription.file
        transcription_file_exists = os.path.exists(transcription_file)

        # Prefer the extracted-speech companion file if one was produced.
        transcription_file_path = Path(transcription_file)
        speech_path = transcription_file_path.with_name(f"{transcription_file_path.stem}_speech.mp3")
        if self.transcription_options.extract_speech and os.path.exists(speech_path):
            transcription_file = str(speech_path)
            transcription_file_exists = True
        # TODO - Fix VAD and silence suppression: they fail to work / download
        # the VAD model in the compiled builds on Mac and Windows.

        try:
            result = stable_whisper.transcribe_any(
                self.get_transcript,
                audio = whisper_audio.load_audio(transcription_file),
                input_sr=whisper_audio.SAMPLE_RATE,
                # vad=transcription_file_exists,
                # suppress_silence=transcription_file_exists,
                vad=False,
                suppress_silence=False,
                regroup=self.regroup_string,
                check_sorted=False,
            )
        except Exception as e:
            # Swallow and log: the caller's UI simply never receives
            # `finished`, leaving the copied transcription in progress.
            logging.error(f"Error in TranscriptionWorker: {e}")
            return

        # Convert stable-whisper's second-based times back to storage units.
        segments = []
        for segment in result.segments:
            segments.append(
                Segment(
                    start=int(segment.start * 100),
                    end=int(segment.end * 100),
                    text=segment.text
                )
            )

        self.finished.emit(segments)
|
||||
|
||||
|
||||
class TranscriptionResizerWidget(QWidget):
    """Dialog offering post-processing of a finished transcription.

    Provides three operations, each of which saves its result as a *copy*
    of the original transcription: extending segment end times, resizing
    subtitles to a target character length, and merging word-level
    timestamps into subtitle-sized segments via stable-whisper regrouping.
    """

    resize_button_clicked = pyqtSignal()
    # The transcription being post-processed.
    transcription: Transcription
    # Class-level Settings instance shared by all widget instances.
    settings = Settings()

    def __init__(
        self,
        transcription: Transcription,
        transcription_service: TranscriptionService,
        parent: Optional["QWidget"] = None,
        flags: Qt.WindowType = Qt.WindowType.Widget,
        transcriptions_updated_signal: Optional[pyqtSignal] = None,
    ) -> None:
        """Build the extend / resize / merge option groups.

        ``transcriptions_updated_signal`` (optional) is emitted with the new
        transcription id whenever an operation produces a saved copy.
        """
        super().__init__(parent, flags)
        self.transcription = transcription
        self.transcription_service = transcription_service
        self.transcriptions_updated_signal = transcriptions_updated_signal

        # Worker-thread state for the (asynchronous) merge operation.
        self.new_transcript_id = None
        self.thread = None
        self.worker = None

        self.setMinimumWidth(600)
        self.setMinimumHeight(300)

        self.setWindowTitle(file_path_as_title(transcription.file))

        preferences = self.load_preferences()

        (
            self.transcription_options,
            self.file_transcription_options,
        ) = preferences.to_transcription_options(
            openai_access_token=''
        )

        layout = QFormLayout(self)

        # Extend segment endings
        extend_label = QLabel(_("Extend end time"), self)
        font = extend_label.font()
        font.setWeight(QFont.Weight.Bold)
        extend_label.setFont(font)
        layout.addRow(extend_label)

        extend_group_box = QGroupBox(self)
        extend_layout = QVBoxLayout(extend_group_box)

        self.extend_row = QHBoxLayout()

        self.extend_amount_label = QLabel(_("Extend endings by up to (seconds)"), self)

        self.extend_amount_input = LineEdit("0.2", self)
        self.extend_amount_input.setMaximumWidth(60)

        self.extend_button = QPushButton(_("Extend endings"))
        self.extend_button.clicked.connect(self.on_extend_button_clicked)

        self.extend_row.addWidget(self.extend_amount_label)
        self.extend_row.addWidget(self.extend_amount_input)
        self.extend_row.addWidget(self.extend_button)

        extend_layout.addLayout(self.extend_row)

        layout.addRow(extend_group_box)

        # Spacer
        spacer1 = QSpacerItem(0, 10, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed)
        layout.addItem(spacer1)

        # Resize longer subtitles
        resize_label = QLabel(_("Resize Options"), self)
        font = resize_label.font()
        font.setWeight(QFont.Weight.Bold)
        resize_label.setFont(font)
        layout.addRow(resize_label)

        resize_group_box = QGroupBox(self)
        resize_layout = QVBoxLayout(resize_group_box)

        self.resize_row = QHBoxLayout()

        self.desired_subtitle_length_label = QLabel(_("Desired subtitle length"), self)

        self.target_chars_spin_box = QSpinBox(self)
        self.target_chars_spin_box.setMinimum(1)
        self.target_chars_spin_box.setMaximum(100)
        self.target_chars_spin_box.setValue(42)

        self.resize_button = QPushButton(_("Resize"))
        self.resize_button.clicked.connect(self.on_resize_button_clicked)

        self.resize_row.addWidget(self.desired_subtitle_length_label)
        self.resize_row.addWidget(self.target_chars_spin_box)
        self.resize_row.addWidget(self.resize_button)

        resize_layout.addLayout(self.resize_row)

        # Resize works on sentence-level segments only; disable it when the
        # transcription was made with word-level timings.
        resize_group_box.setEnabled(self.transcription.word_level_timings != 1)
        if self.transcription.word_level_timings == 1:
            resize_group_box.setToolTip(_("Available only if word level timings were disabled during transcription"))

        layout.addRow(resize_group_box)

        # Spacer
        spacer2 = QSpacerItem(0, 10, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed)
        layout.addItem(spacer2)

        # Merge words into subtitles
        merge_options_label = QLabel(_("Merge Options"), self)
        font = merge_options_label.font()
        font.setWeight(QFont.Weight.Bold)
        merge_options_label.setFont(font)
        layout.addRow(merge_options_label)

        merge_options_group_box = QGroupBox(self)
        merge_options_layout = QVBoxLayout(merge_options_group_box)

        self.merge_options_row = QVBoxLayout()

        self.merge_by_gap = QCheckBox(_("Merge by gap"))
        self.merge_by_gap.setChecked(True)
        self.merge_by_gap.setMinimumWidth(250)
        self.merge_by_gap_input = LineEdit("0.2", self)
        merge_by_gap_layout = QHBoxLayout()
        merge_by_gap_layout.addWidget(self.merge_by_gap)
        merge_by_gap_layout.addWidget(self.merge_by_gap_input)

        self.split_by_punctuation = QCheckBox(_("Split by punctuation"))
        self.split_by_punctuation.setChecked(True)
        self.split_by_punctuation.setMinimumWidth(250)
        self.split_by_punctuation_input = LineEdit(".* /./. /。/?/? /?/!/! /!/,/, ", self)
        split_by_punctuation_layout = QHBoxLayout()
        split_by_punctuation_layout.addWidget(self.split_by_punctuation)
        split_by_punctuation_layout.addWidget(self.split_by_punctuation_input)

        self.split_by_max_length = QCheckBox(_("Split by max length"))
        self.split_by_max_length.setChecked(True)
        self.split_by_max_length.setMinimumWidth(250)
        self.split_by_max_length_input = LineEdit("42", self)
        split_by_max_length_layout = QHBoxLayout()
        split_by_max_length_layout.addWidget(self.split_by_max_length)
        split_by_max_length_layout.addWidget(self.split_by_max_length_input)

        self.merge_options_row.addLayout(merge_by_gap_layout)
        self.merge_options_row.addLayout(split_by_punctuation_layout)
        self.merge_options_row.addLayout(split_by_max_length_layout)

        self.merge_button = QPushButton(_("Merge"))
        self.merge_button.clicked.connect(self.on_merge_button_clicked)

        self.merge_options_row.addWidget(self.merge_button)

        merge_options_layout.addLayout(self.merge_options_row)

        # Merge needs word-level timings; disable otherwise (mirror of the
        # resize group above).
        merge_options_group_box.setEnabled(self.transcription.word_level_timings == 1)
        if self.transcription.word_level_timings != 1:
            merge_options_group_box.setToolTip(_("Available only if word level timings were enabled during transcription"))

        layout.addRow(merge_options_group_box)

        self.setLayout(layout)

    def load_preferences(self):
        """Load the file-transcriber preferences group from settings."""
        self.settings.settings.beginGroup("file_transcriber")
        preferences = FileTranscriptionPreferences.load(settings=self.settings.settings)
        self.settings.settings.endGroup()
        return preferences

    def on_resize_button_clicked(self):
        """Split long subtitles to the target character length and save a copy."""
        segments = self.transcription_service.get_transcription_segments(
            transcription_id=self.transcription.id_as_uuid
        )

        # NOTE(review): srt.Subtitle normally carries timedeltas; here raw
        # segment times are passed through and rounded back below — confirm
        # srt_equalizer.split_subtitle accepts numeric start/end.
        subs = []
        for segment in segments:
            subtitle = srt.Subtitle(
                index=segment.id,
                start=segment.start_time,
                end=segment.end_time,
                content=segment.text
            )
            subs.append(subtitle)

        resized_subs = []
        last_index = 0

        # Limit each subtitle to a maximum character length, splitting into
        # multiple subtitle items if necessary.
        for sub in subs:
            new_subs = srt_equalizer.split_subtitle(
                sub=sub, target_chars=self.target_chars_spin_box.value(), start_from_index=last_index, method="punctuation")
            last_index = new_subs[-1].index
            resized_subs.extend(new_subs)

        # Drop zero-length results produced by the split.
        segments = [
            Segment(
                round(sub.start),
                round(sub.end),
                sub.content
            )
            for sub in resized_subs
            if round(sub.start) != round(sub.end)
        ]

        new_transcript_id = self.transcription_service.copy_transcription(
            self.transcription.id_as_uuid
        )
        self.transcription_service.update_transcription_as_completed(new_transcript_id, segments)

        if self.transcriptions_updated_signal:
            self.transcriptions_updated_signal.emit(new_transcript_id)

    def on_extend_button_clicked(self):
        """Extend every segment's end time (clamped to the next segment) and save a copy."""
        try:
            extend_amount_seconds = float(self.extend_amount_input.text())
        except ValueError:
            # Fall back to the field's default when the input isn't a number.
            extend_amount_seconds = 0.2

        # Convert seconds to milliseconds (internal time unit)
        extend_amount = int(extend_amount_seconds * 1000)

        segments = self.transcription_service.get_transcription_segments(
            transcription_id=self.transcription.id_as_uuid
        )

        extended_segments = []
        for i, segment in enumerate(segments):
            new_end = segment.end_time + extend_amount

            # Ensure segment end doesn't exceed start of next segment
            if i < len(segments) - 1:
                next_start = segments[i + 1].start_time
                new_end = min(new_end, next_start)

            extended_segments.append(
                Segment(
                    start=segment.start_time,
                    end=new_end,
                    text=segment.text
                )
            )

        new_transcript_id = self.transcription_service.copy_transcription(
            self.transcription.id_as_uuid
        )
        self.transcription_service.update_transcription_as_completed(new_transcript_id, extended_segments)

        if self.transcriptions_updated_signal:
            self.transcriptions_updated_signal.emit(new_transcript_id)

    def on_merge_button_clicked(self):
        """Kick off the asynchronous merge (regroup) of word-level segments.

        Creates the destination transcription copy up front (marked
        in-progress), builds a stable-whisper regroup rule string from the
        checked options, then runs TranscriptionWorker on a QThread.
        """
        self.new_transcript_id = self.transcription_service.copy_transcription(
            self.transcription.id_as_uuid
        )
        # Mark the copy as in-progress so the UI shows it immediately.
        self.transcription_service.update_transcription_progress(self.new_transcript_id, 0.0)

        if self.transcriptions_updated_signal:
            self.transcriptions_updated_signal.emit(self.new_transcript_id)

        # Assemble the stable-whisper regroup rule ("_" separates methods).
        regroup_string = ''
        if self.merge_by_gap.isChecked():
            regroup_string += f'mg={self.merge_by_gap_input.text()}'

            # NOTE(review): the '++…+1' suffix is reconstructed as extra
            # arguments to the mg rule — confirm against stable-whisper's
            # regroup string syntax.
            if self.split_by_max_length.isChecked():
                regroup_string += f'++{self.split_by_max_length_input.text()}+1'

        if self.split_by_punctuation.isChecked():
            if regroup_string:
                regroup_string += '_'
            regroup_string += f'sp={self.split_by_punctuation_input.text()}'

        if self.split_by_max_length.isChecked():
            if regroup_string:
                regroup_string += '_'
            regroup_string += f'sl={self.split_by_max_length_input.text()}'

        # Environment override for power users / debugging.
        regroup_string = os.getenv("BUZZ_MERGE_REGROUP_RULE", regroup_string)

        self.hide()

        self.thread = QThread()
        self.worker = TranscriptionWorker(
            self.transcription,
            self.transcription_options,
            self.transcription_service,
            regroup_string
        )
        self.worker.moveToThread(self.thread)
        self.thread.started.connect(self.worker.run)
        self.worker.finished.connect(self.thread.quit)
        self.worker.finished.connect(self.worker.deleteLater)
        self.thread.finished.connect(self.thread.deleteLater)
        self.worker.finished.connect(self.on_transcription_completed)

        self.thread.start()

    def on_transcription_completed(self, segments):
        """Persist the worker's merged segments into the prepared copy and close."""
        if self.new_transcript_id is not None:
            self.transcription_service.update_transcription_as_completed(self.new_transcript_id, segments)

            if self.transcriptions_updated_signal:
                self.transcriptions_updated_signal.emit(self.new_transcript_id)

        self.close()

    def closeEvent(self, event):
        """Hide before closing so the window disappears promptly."""
        self.hide()

        super().closeEvent(event)
|
||||
|
|
@ -4,8 +4,7 @@ from dataclasses import dataclass
|
|||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex, QItemSelection, QEvent, QRegularExpression, QObject
|
||||
from PyQt6.QtGui import QRegularExpressionValidator
|
||||
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex, QItemSelection
|
||||
from PyQt6.QtSql import QSqlTableModel, QSqlRecord
|
||||
from PyQt6.QtGui import QFontMetrics, QTextOption
|
||||
from PyQt6.QtWidgets import (
|
||||
|
|
@ -14,7 +13,6 @@ from PyQt6.QtWidgets import (
|
|||
QStyledItemDelegate,
|
||||
QAbstractItemView,
|
||||
QTextEdit,
|
||||
QLineEdit,
|
||||
)
|
||||
|
||||
from buzz.locale import _
|
||||
|
|
@ -39,188 +37,16 @@ class ColDef:
|
|||
delegate: Optional[QStyledItemDelegate] = None
|
||||
|
||||
|
||||
def parse_timestamp(timestamp_str: str) -> Optional[int]:
    """Parse a timestamp string into milliseconds.

    Accepts "HH:MM:SS.mmm", "MM:SS.mmm" or "SS.mmm" (the fractional part
    is optional). Returns the total milliseconds, or None when the string
    cannot be parsed.
    """
    try:
        # Handle formats like "00:01:23.456" or "1:23.456" or "23.456"
        parts = timestamp_str.strip().split(':')

        if len(parts) == 3:  # HH:MM:SS(.mmm)
            hours = int(parts[0])
            minutes = int(parts[1])
            seconds_parts = parts[2].split('.')
        elif len(parts) == 2:  # MM:SS(.mmm)
            hours = 0
            minutes = int(parts[0])
            seconds_parts = parts[1].split('.')
        elif len(parts) == 1:  # SS(.mmm)
            hours = 0
            minutes = 0
            seconds_parts = parts[0].split('.')
        else:
            return None

        seconds = int(seconds_parts[0])

        # BUGFIX: the fractional part is a decimal fraction of a second, so
        # ".5" must mean 500 ms (previously int("5") gave 5 ms, and more
        # than three digits overflowed into full seconds). Truncate to
        # millisecond precision and right-pad shorter fractions with zeros.
        if len(seconds_parts) > 1:
            milliseconds = int(seconds_parts[1][:3].ljust(3, '0'))
        else:
            milliseconds = 0

        total_ms = hours * 3600 * 1000 + minutes * 60 * 1000 + seconds * 1000 + milliseconds
        return total_ms
    except (ValueError, IndexError):
        return None
|
||||
|
||||
|
||||
class TimeStampLineEdit(QLineEdit):
    """Custom QLineEdit for timestamp editing with keyboard shortcuts.

    Holds the authoritative value in ``_milliseconds`` and mirrors it as
    formatted text. '+'/'-' nudge the value by 500 ms; invalid text is
    reverted to the last valid value on focus-out.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # Last known-valid value in milliseconds; fallback when parsing fails.
        self._milliseconds = 0

        # Set up validator to only allow digits, colons, and dots
        regex = QRegularExpression(r'^[0-9:.]*$')
        validator = QRegularExpressionValidator(regex, self)
        self.setValidator(validator)

    def set_milliseconds(self, ms: int):
        """Set the stored value and display it as a formatted timestamp."""
        self._milliseconds = ms
        self.setText(to_timestamp(ms))

    def get_milliseconds(self) -> int:
        """Return the parsed current text, or the last valid value if unparsable."""
        parsed = parse_timestamp(self.text())
        if parsed is not None:
            return parsed
        return self._milliseconds

    def keyPressEvent(self, event):
        """'+' / '-' adjust the value by 500 ms; other keys edit text normally."""
        if event.text() == '+':
            self._milliseconds += 500  # Add 500ms (0.5 seconds)
            self.setText(to_timestamp(self._milliseconds))
            event.accept()
        elif event.text() == '-':
            # Clamp at zero so the timestamp can never go negative.
            self._milliseconds = max(0, self._milliseconds - 500)  # Subtract 500ms
            self.setText(to_timestamp(self._milliseconds))
            event.accept()
        else:
            super().keyPressEvent(event)

    def focusOutEvent(self, event):
        # Strip any invalid characters and reformat on focus out
        parsed = parse_timestamp(self.text())
        if parsed is not None:
            parsed = parsed
            self._milliseconds = parsed
            self.setText(to_timestamp(parsed))
        else:
            # If parsing failed, restore the last valid value
            self.setText(to_timestamp(self._milliseconds))
        super().focusOutEvent(event)
|
||||
|
||||
|
||||
class TimeStampDelegate(QStyledItemDelegate):
    """Read-only delegate that renders a stored time value as a timestamp string."""

    def displayText(self, value, locale):
        # Value is the raw time from the model (presumably milliseconds —
        # to_timestamp defines the actual unit; confirm there).
        return to_timestamp(value)
|
||||
|
||||
|
||||
class TimeStampEditorDelegate(QStyledItemDelegate):
    """Delegate for editing timestamps with overlap prevention.

    Edits go through TimeStampLineEdit. While typing, live values are
    broadcast via ``timestamp_editing``. On commit, start/end ordering is
    validated and neighbouring segments are adjusted so segments never
    overlap.
    """

    timestamp_editing = pyqtSignal(int, int, int)  # Signal: (row, column, new_value_ms)

    def createEditor(self, parent, option, index):
        """Create a TimeStampLineEdit wired to emit live edit updates."""
        editor = TimeStampLineEdit(parent)
        # Connect text changed signal to emit live updates
        editor.textChanged.connect(lambda: self.on_editor_text_changed(editor, index))
        return editor

    def on_editor_text_changed(self, editor, index):
        """Emit signal when editor text changes with the current value"""
        new_value_ms = editor.get_milliseconds()
        self.timestamp_editing.emit(index.row(), index.column(), new_value_ms)

    def setEditorData(self, editor, index):
        # Get value in milliseconds from database
        value = index.model().data(index, Qt.ItemDataRole.EditRole)
        if value is not None:
            editor.set_milliseconds(value)

    def setModelData(self, editor, model, index):
        """Validate the edited time and write it (plus neighbour fixes) back.

        Rejects values that would invert a segment (start >= end); when the
        new value overlaps an adjacent segment, that neighbour's boundary is
        moved to match so the timeline stays non-overlapping.
        """
        # Get value in milliseconds from editor
        new_value_ms = editor.get_milliseconds()
        current_row = index.row()
        column = index.column()

        # Get current segment's start and end
        start_col = Column.START.value
        end_col = Column.END.value

        if column == start_col:
            # Editing START time
            end_time_ms = model.record(current_row).value("end_time")

            if end_time_ms is None:
                logging.warning("End time is None, cannot validate")
                return

            # Validate: start must be less than end
            if new_value_ms >= end_time_ms:
                logging.warning(f"Start time ({new_value_ms}) must be less than end time ({end_time_ms})")
                return

            # Check if new start overlaps with previous segment's end
            if current_row > 0:
                prev_end_time_ms = model.record(current_row - 1).value("end_time")
                if prev_end_time_ms is not None and new_value_ms < prev_end_time_ms:
                    # Update previous segment's end to match new start
                    model.setData(model.index(current_row - 1, end_col), new_value_ms)

        elif column == end_col:
            # Editing END time
            start_time_ms = model.record(current_row).value("start_time")

            if start_time_ms is None:
                logging.warning("Start time is None, cannot validate")
                return

            # Validate: end must be greater than start
            if new_value_ms <= start_time_ms:
                logging.warning(f"End time ({new_value_ms}) must be greater than start time ({start_time_ms})")
                return

            # Check if new end overlaps with next segment's start
            if current_row < model.rowCount() - 1:
                next_start_time_ms = model.record(current_row + 1).value("start_time")
                if next_start_time_ms is not None and new_value_ms > next_start_time_ms:
                    # Update next segment's start to match new end
                    model.setData(model.index(current_row + 1, start_col), new_value_ms)

        # Set the new value
        model.setData(index, new_value_ms)

    def displayText(self, value, locale):
        # Render the stored time value as a formatted timestamp.
        return to_timestamp(value)
|
||||
|
||||
|
||||
class CustomTextEdit(QTextEdit):
    """QTextEdit whose Tab/Enter/Esc keys end the edit by dropping focus."""

    # Keys that should finish the editing session instead of inserting text.
    _COMMIT_KEYS = (Qt.Key.Key_Tab, Qt.Key.Key_Return, Qt.Key.Key_Enter, Qt.Key.Key_Escape)

    def __init__(self, parent=None):
        super().__init__(parent)

    def keyPressEvent(self, event):
        """Close the editor on commit keys; otherwise edit text normally."""
        if event.key() not in self._COMMIT_KEYS:
            super().keyPressEvent(event)
            return
        # Losing focus closes the item editor, which triggers setModelData
        # and saves the content.
        self.clearFocus()
        event.accept()
|
||||
|
||||
|
||||
class WordWrapDelegate(QStyledItemDelegate):
    """Delegate that edits cells in a multi-line, word-wrapping text editor."""

    def createEditor(self, parent, option, index):
        """Return a plain-text editor configured for word wrapping.

        BUGFIX: the original created a CustomTextEdit and then immediately
        overwrote the variable with a plain QTextEdit, leaving the first
        instance dead and losing CustomTextEdit's Tab/Enter/Esc commit
        handling. Keep the CustomTextEdit. (NOTE: this looks like a merge
        artifact of two revisions — confirm against upstream intent.)
        """
        editor = CustomTextEdit(parent)
        editor.setWordWrapMode(QTextOption.WrapMode.WordWrap)
        editor.setAcceptRichText(False)
        editor.setTabChangesFocus(True)

        return editor
|
||||
|
||||
|
|
@ -235,21 +61,16 @@ class TranscriptionSegmentModel(QSqlTableModel):
|
|||
self.setEditStrategy(QSqlTableModel.EditStrategy.OnFieldChange)
|
||||
self.setFilter(f"transcription_id = '{transcription_id}'")
|
||||
|
||||
def flags(self, index: QModelIndex):
    """Item flags for a cell; the start/end columns are not editable inline."""
    item_flags = super().flags(index)
    timestamp_columns = (Column.START.value, Column.END.value)
    if index.column() in timestamp_columns:
        # Timestamps are shown read-only in this view.
        item_flags = item_flags & ~Qt.ItemFlag.ItemIsEditable
    return item_flags
|
||||
|
||||
|
||||
class TranscriptionSegmentsEditorWidget(QTableView):
|
||||
PARENT_PADDINGS = 40
|
||||
segment_selected = pyqtSignal(QSqlRecord)
|
||||
timestamp_being_edited = pyqtSignal(int, int, int) # Signal: (row, column, new_value_ms)
|
||||
|
||||
def keyPressEvent(self, event):
    """Start editing the current cell when Enter/Return is pressed."""
    pressed_activation_key = event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter)
    if pressed_activation_key:
        index = self.currentIndex()
        already_editing = self.state() == QAbstractItemView.State.EditingState
        if index.isValid() and not already_editing:
            # Open the editor for the selected cell and swallow the key.
            self.edit(index)
            event.accept()
            return
    super().keyPressEvent(event)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -259,22 +80,18 @@ class TranscriptionSegmentsEditorWidget(QTableView):
|
|||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self._last_highlighted_row = -1
|
||||
self.translator = translator
|
||||
self.translator.translation.connect(self.update_translation)
|
||||
|
||||
model = TranscriptionSegmentModel(transcription_id=transcription_id)
|
||||
self.setModel(model)
|
||||
|
||||
timestamp_editor_delegate = TimeStampEditorDelegate()
|
||||
# Connect delegate's signal to widget's signal
|
||||
timestamp_editor_delegate.timestamp_editing.connect(self.timestamp_being_edited.emit)
|
||||
|
||||
timestamp_delegate = TimeStampDelegate()
|
||||
word_wrap_delegate = WordWrapDelegate()
|
||||
|
||||
self.column_definitions: list[ColDef] = [
|
||||
ColDef("start", _("Start"), Column.START, delegate=timestamp_editor_delegate),
|
||||
ColDef("end", _("End"), Column.END, delegate=timestamp_editor_delegate),
|
||||
ColDef("start", _("Start"), Column.START, delegate=timestamp_delegate),
|
||||
ColDef("end", _("End"), Column.END, delegate=timestamp_delegate),
|
||||
ColDef("text", _("Text"), Column.TEXT, delegate=word_wrap_delegate),
|
||||
ColDef("translation", _("Translation"), Column.TRANSLATION, delegate=word_wrap_delegate),
|
||||
]
|
||||
|
|
@ -298,10 +115,6 @@ class TranscriptionSegmentsEditorWidget(QTableView):
|
|||
self.verticalHeader().hide()
|
||||
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
self.setSelectionMode(QTableView.SelectionMode.SingleSelection)
|
||||
self.setEditTriggers(
|
||||
QAbstractItemView.EditTrigger.EditKeyPressed |
|
||||
QAbstractItemView.EditTrigger.DoubleClicked
|
||||
)
|
||||
self.selectionModel().selectionChanged.connect(self.on_selection_changed)
|
||||
model.select()
|
||||
model.rowsInserted.connect(self.init_row_height)
|
||||
|
|
@ -313,8 +126,8 @@ class TranscriptionSegmentsEditorWidget(QTableView):
|
|||
|
||||
self.init_row_height()
|
||||
|
||||
self.setColumnWidth(Column.START.value, 120)
|
||||
self.setColumnWidth(Column.END.value, 120)
|
||||
self.setColumnWidth(Column.START.value, 95)
|
||||
self.setColumnWidth(Column.END.value, 95)
|
||||
|
||||
self.setWordWrap(True)
|
||||
|
||||
|
|
@ -369,17 +182,3 @@ class TranscriptionSegmentsEditorWidget(QTableView):
|
|||
|
||||
def segments(self) -> list[QSqlRecord]:
    """Return one QSqlRecord per row currently held by the model."""
    model = self.model()
    return [model.record(row) for row in range(model.rowCount())]
|
||||
|
||||
def highlight_and_scroll_to_row(self, row_index: int):
    """Highlight a specific row and scroll it into view.

    Out-of-range indices are ignored. Focus is only (re)taken when the
    highlighted row actually changes.
    """
    if 0 <= row_index < self.model().rowCount():
        # Only set focus if we're actually moving to a different row to avoid audio crackling
        if self._last_highlighted_row != row_index:
            self.setFocus()
            self._last_highlighted_row = row_index

        # Select the row
        self.selectRow(row_index)
        # Scroll to the row with better positioning
        model_index = self.model().index(row_index, 0)
        self.scrollTo(model_index, QAbstractItemView.ScrollHint.PositionAtCenter)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -21,22 +20,13 @@ class ViewMode(Enum):
|
|||
class TranscriptionViewModeToolButton(QToolButton):
|
||||
view_mode_changed = pyqtSignal(ViewMode)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shortcuts: Shortcuts,
|
||||
has_translation: bool,
|
||||
translation: pyqtSignal,
|
||||
parent: Optional[QWidget] = None
|
||||
):
|
||||
def __init__(self, shortcuts: Shortcuts, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.setText(_("View"))
|
||||
self.setIcon(VisibilityIcon(self))
|
||||
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextBesideIcon)
|
||||
self.setPopupMode(QToolButton.ToolButtonPopupMode.MenuButtonPopup)
|
||||
self.setMinimumWidth(80)
|
||||
|
||||
translation.connect(self.on_translation_available)
|
||||
self.setPopupMode(QToolButton.ToolButtonPopupMode.InstantPopup)
|
||||
|
||||
menu = QMenu(self)
|
||||
|
||||
|
|
@ -46,12 +36,11 @@ class TranscriptionViewModeToolButton(QToolButton):
|
|||
lambda: self.view_mode_changed.emit(ViewMode.TEXT),
|
||||
)
|
||||
|
||||
self.translation_action = menu.addAction(
|
||||
menu.addAction(
|
||||
_("Translation"),
|
||||
QKeySequence(shortcuts.get(Shortcut.VIEW_TRANSCRIPT_TRANSLATION)),
|
||||
lambda: self.view_mode_changed.emit(ViewMode.TRANSLATION)
|
||||
)
|
||||
self.translation_action.setVisible(has_translation)
|
||||
|
||||
menu.addAction(
|
||||
_("Timestamps"),
|
||||
|
|
@ -60,7 +49,3 @@ class TranscriptionViewModeToolButton(QToolButton):
|
|||
)
|
||||
|
||||
self.setMenu(menu)
|
||||
self.clicked.connect(self.showMenu)
|
||||
|
||||
def on_translation_available(self):
|
||||
self.translation_action.setVisible(True)
|
||||
|
|
|
|||